diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..372799f
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,8 @@
+FROM git.modelhub.org.cn:9443/enginex-iluvatar/mr-bi150-4.3.0-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.3
+
+RUN mkdir /workspace
+WORKDIR /workspace/
+
+COPY ./launch_service /workspace/launch_service
+
+ENTRYPOINT ["./launch_service"]
diff --git a/README.md b/README.md
index 52da238..4c5907d 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,111 @@
-# README
\ No newline at end of file
+# Iluvatar CoreX 智铠100 Text Generation Engine (Optimized on vLLM)
+
+This project is a high-performance text-generation inference engine deeply optimized for the **Iluvatar CoreX 智铠100** accelerator. It builds on the open-source **vLLM** framework with architecture-level adaptations and enhancements, providing efficient support for the latest large models such as the **Qwen3 series**. Optimizations such as **Prefix Caching** and PagedAttention significantly improve throughput and response latency, and the engine exposes a standard **OpenAI-compatible API** for seamless integration with existing applications.
+
+## Supported Models
+
+- **Qwen3**
+- **Llama3**
+- **DeepSeek-R1-Distill**
+- Other vLLM-compatible HuggingFace models (coverage is being extended)
+
+> Model downloads: [https://modelscope.cn/models/Qwen](https://modelscope.cn/models/Qwen)
+
+---
+
+## Quick Start
+
+### 1. Download the Model
+
+Download the required model from ModelScope (using Qwen2.5-7B-Instruct as an example):
+
+```bash
+modelscope download --model qwen/Qwen2.5-7B-Instruct --local_dir /mnt/models/Qwen2.5-7B-Instruct
+```
+
+> ⚠️ Make sure this model path is mounted correctly when the Docker container is started later.
+
+---
+
+### 2. Pull and Build the Docker Image
+
+We provide a Docker image with the 智铠100 driver stack and the optimized vLLM build preinstalled:
+
+```bash
+# Build locally
+docker build -t enginex-iluvatar-vllm:bi100 -f Dockerfile .
+```
+
+---
+
+### 3. Start the Service Container
+
+```bash
+docker run -it --rm -p 8000:80 \
+  --name vllm-iluvatar \
+  -v /mnt/models/Qwen2.5-7B-Instruct:/model:ro \
+  --privileged \
+  -e TENSOR_PARALLEL_SIZE=1 \
+  -e PREFIX_CACHING=true \
+  -e MAX_MODEL_LEN=10000 \
+  enginex-iluvatar-vllm:bi100
+```
+
+> ✅ Parameter notes:
+> - `PREFIX_CACHING=true`: enables Prefix Caching, which significantly improves efficiency when multiple requests share a common prefix
+> - `MAX_MODEL_LEN=10000`: enables long-context inference
+> - `--privileged`: ensures the 智铠100 devices are visible inside the container
+
+---
+
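+> 💡 Before sending generation requests, you can verify that the server is up by listing the served models. The command below assumes the port mapping from the `docker run` example above (host port 8000):
+
+```bash
+curl http://localhost:8000/v1/models
+```
+
+> The response should list the model under the name configured via `SERVED_MODEL_NAME` (default `llm`), which is also the value to use in the `model` field of the requests in the next section.
+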
+### 4. Test the Service (OpenAI-Compatible API)
+
+Once the service is running, it can be tested with the standard OpenAI SDK or `curl`.
+
+> The `model` field in each request must match the served model name configured via `SERVED_MODEL_NAME` (default: `llm`).
+
+#### Example: Text Generation Request
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "llm",
+    "messages": [
+      {"role": "system", "content": "You are a helpful assistant."},
+      {"role": "user", "content": "请用中文介绍一下上海的特点。"}
+    ],
+    "temperature": 0.7,
+    "max_tokens": 512
+  }'
+```
+
+#### Using the OpenAI Python SDK (requires `openai>=1.0`)
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
+
+response = client.chat.completions.create(
+    model="llm",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "请简要介绍杭州的特色文化。"}
+    ],
+    max_tokens=512,
+    temperature=0.7
+)
+
+print(response.choices[0].message.content)
+```
+
+---
+
+## Benchmark Comparison (A100 vs 智铠100)
+
+With the same models and inputs, the average output speed (characters per second) was measured as follows:
+
+| Model | 智铠100 output speed (chars/s) | NVIDIA A100 output speed (chars/s) |
+|--------|--------------------------|-------------------------------|
+| Qwen2.5-7B-Instruct | 56.4 | 112.4 |
+| Qwen2.5-1.5B-Instruct-AWQ | 123.1 | 100.8 |
+
diff --git a/launch_service b/launch_service
new file mode 100755
index 0000000..f1571ed
--- /dev/null
+++ b/launch_service
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+export PYTHONPATH=/usr/local/corex/lib64/python3/dist-packages
+export LD_LIBRARY_PATH=/usr/local/corex/lib64:/usr/local/openmpi/lib
+export PATH=/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin
+export JAVA_HOME=/root/apps/jdk1.8.0_411
+export JRE_HOME=/root/apps/jdk1.8.0_411/jre
+export JMETER_HOME=/root/apps/apache-jmeter-5.6.3
+export CLASSPATH=.:/root/apps/jdk1.8.0_411/lib/dt.jar:/root/apps/jdk1.8.0_411/lib/tools.jar:/root/apps/apache-jmeter-5.6.3/lib/ext/ApacheJMeter_core.jar:/root/apps/apache-jmeter-5.6.3/lib/jorphan.jar:/root/apps/apache-jmeter-5.6.3/lib/logkit-2.0.jar:
+export PATH=/root/apps/apache-jmeter-5.6.3/bin:/root/apps/jdk1.8.0_411/bin:/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin
+/iluvatar/welcome.sh
+
+date
+cat /proc/cpuinfo | tail -n 50
+ixsmi
+unset CUDA_VISIBLE_DEVICES
+export
+date
+
+DEFAULT_HOST="0.0.0.0"
+DEFAULT_PORT="80"
+DEFAULT_SERVED_MODEL_NAME="llm"
+DEFAULT_MODEL_PATH="/model"
+DEFAULT_MAX_MODEL_LEN="10000"
+DEFAULT_TENSOR_PARALLEL_SIZE="1"
+DEFAULT_MAX_NUM_SEQS="64"
+DEFAULT_ENFORCE_EAGER="true"
+DEFAULT_DISABLE_LOG_REQUESTS="true"
+DEFAULT_PREFIX_CACHING="true"
+
+HOST_VAL=${HOST:-$DEFAULT_HOST}
+PORT_VAL=${PORT:-$DEFAULT_PORT}
+SERVED_MODEL_NAME_VAL=${SERVED_MODEL_NAME:-$DEFAULT_SERVED_MODEL_NAME}
+MODEL_PATH_VAL=${MODEL_PATH:-$DEFAULT_MODEL_PATH}
+MAX_MODEL_LEN_VAL=${MAX_MODEL_LEN:-$DEFAULT_MAX_MODEL_LEN}
+TENSOR_PARALLEL_SIZE_VAL=${TENSOR_PARALLEL_SIZE:-$DEFAULT_TENSOR_PARALLEL_SIZE}
+MAX_NUM_SEQS_VAL=${MAX_NUM_SEQS:-$DEFAULT_MAX_NUM_SEQS}
+INCLUDE_ENFORCE_EAGER_FLAG=${ENFORCE_EAGER:-$DEFAULT_ENFORCE_EAGER}
+INCLUDE_DISABLE_LOG_REQUESTS_FLAG=${DISABLE_LOG_REQUESTS:-$DEFAULT_DISABLE_LOG_REQUESTS}
+INCLUDE_PREFIX_CACHING_FLAG=${PREFIX_CACHING:-$DEFAULT_PREFIX_CACHING}
+
+CMD_ARGS=()
+CMD_ARGS+=(--host "$HOST_VAL")
+CMD_ARGS+=(--port "$PORT_VAL")
+
+if [[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]]; then
+    CMD_ARGS+=(--enforce-eager)
+fi
+if [[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]]; then
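+    # --disable-log-requests turns off vLLM's per-request logging; the flag is
+    # added whenever DISABLE_LOG_REQUESTS is not "false"/"0" (default: true).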
CMD_ARGS+=(--disable-log-requests) +fi +if [[ "$INCLUDE_PREFIX_CACHING_FLAG" != "false" && "$INCLUDE_PREFIX_CACHING_FLAG" != "0" ]]; then + CMD_ARGS+=(--enable-prefix-caching) +fi + +CMD_ARGS+=(--served-model-name "$SERVED_MODEL_NAME_VAL") +CMD_ARGS+=(--model "$MODEL_PATH_VAL") +CMD_ARGS+=(--max-model-len "$MAX_MODEL_LEN_VAL") +CMD_ARGS+=(--tensor-parallel-size "$TENSOR_PARALLEL_SIZE_VAL") +CMD_ARGS+=(--max-num-seqs "$MAX_NUM_SEQS_VAL") + +echo "--------------------------------------------------" +echo "Starting VLLM OpenAI API Server..." +echo "Using effective arguments:" +echo " Host (--host): $HOST_VAL" +echo " Port (--port): $PORT_VAL" +echo " Enforce Eager (--enforce-eager):" $([[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: ENFORCE_EAGER=$ENFORCE_EAGER)") +echo " Disable Log Req (--disable-log-requests):" $([[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: DISABLE_LOG_REQUESTS=$DISABLE_LOG_REQUESTS)") +echo " Served Model Name (--served-model-name): $SERVED_MODEL_NAME_VAL" +echo " Model Path (--model): $MODEL_PATH_VAL" +echo " Max Model Length (--max-model-len): $MAX_MODEL_LEN_VAL" +echo " Tensor Parallel Size (--tensor-parallel-size): $TENSOR_PARALLEL_SIZE_VAL" +echo " Max Num Seqs (--max-num-seqs): $MAX_NUM_SEQS_VAL" +echo "--------------------------------------------------" +echo "Full cmd:" +echo "python3 -m vllm.entrypoints.openai.api_server ${CMD_ARGS[*]}" +echo "--------------------------------------------------" + +python3 -m vllm.entrypoints.openai.api_server "${CMD_ARGS[@]}" diff --git a/vllm/__init__.py b/vllm/__init__.py new file mode 100644 index 0000000..52022fb --- /dev/null +++ b/vllm/__init__.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +# The version.py should be independent library, and we always import the +# version library first. Such assumption is critical for some customization. +from .version import __version__, __version_tuple__ # isort:skip + +# The environment variables override should be imported before any other +# modules to ensure that the environment variables are set before any +# other modules are imported. 
+import vllm.env_override # isort:skip # noqa: F401 + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.llm import LLM +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.model_executor.models import ModelRegistry +from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput, + CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, PoolingOutput, + PoolingRequestOutput, RequestOutput, ScoringOutput, + ScoringRequestOutput) +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams + +__all__ = [ + "__version__", + "__version_tuple__", + "LLM", + "ModelRegistry", + "PromptType", + "TextPrompt", + "TokensPrompt", + "SamplingParams", + "RequestOutput", + "CompletionOutput", + "PoolingOutput", + "PoolingRequestOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", + "ClassificationOutput", + "ClassificationRequestOutput", + "ScoringOutput", + "ScoringRequestOutput", + "LLMEngine", + "EngineArgs", + "AsyncLLMEngine", + "AsyncEngineArgs", + "initialize_ray_cluster", + "PoolingParams", +] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 0000000..2a9e0f2 --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,1570 @@ +import functools +from typing import TYPE_CHECKING, List, Optional, Tuple, Dict, Any + +import torch +import torch.library +import torch.nn.functional as F +import math +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType + +import ixformer.inference.functions as ops +from ixformer.distributed import _distributed as cdist + + +logger = init_logger(__name__) + +supports_moe_ops = True + +def register_fake(fn): + return lambda name: fn + + +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ops.silu_and_mul(x, out) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ops.gelu_and_mul(x, out) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ops.gelu_tanh_and_mul(x, out) + + +def fatrelu_and_mul(out: torch.Tensor, + x: torch.Tensor, + threshold: float = 0.0) -> None: + raise NotImplementedError("FIX soon") + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(F.gelu(x,approximate="tanh")) + return out + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(F.gelu(x,approximate="tanh")) + return out + + +def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(F.gelu(x,approximate="tanh")) + return out + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + 
key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + query_start_loc: Optional[torch.Tensor], + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def mla_decode_kvcache_cpu( + out: torch.Tensor, + query: torch.Tensor, + kv_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, +) -> None: + raise NotImplementedError("Do not use this in our implement") + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + ops.vllm_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + ops.vllm_batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + ops.rms_norm(input, weight, epsilon, out) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float, + residual_alpha: Optional[float] = 1) -> None: + ops.residual_rms_norm(input=input, weight=weight, residual=residual, eps=epsilon, residual_alpha=residual_alpha) + + +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return ops.advance_step_flashattn(num_seqs, num_queries, + block_size, input_tokens, + sampled_token_ids, + input_positions, seq_lens, + slot_mapping, block_tables) + + +def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + block_table_bound: torch.Tensor) -> None: + raise NotImplementedError("FIX soon") + + +# fused quant layer norm ops +def rms_norm_dynamic_per_token_quant( + input: torch.Tensor, + weight: torch.Tensor, + 
epsilon: float, + quant_dtype: torch.dtype, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError("FIX soon") + output = torch.empty_like(input, dtype=quant_dtype) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + + ops.rms_norm_dynamic_per_token_quant(output, input, weight, + scales, epsilon, scale_ub, + residual) + return output, scales + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, + pack_factor, group_size: int = 128) -> torch.Tensor: + return ops.wui4a16(input, qweight, scales, qzeros, None, group_size, "NN") + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + return ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros ,b_gptq_scales, + b_g_idx, use_exllama, bit) + + +if hasattr(ops, "gptq_gemm"): + + @register_fake("_C::gptq_gemm") + def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, + use_exllama: bool, bit: int) -> torch.Tensor: + return torch.empty((a.size(0), b_q_weight.size(1)), + dtype=a.dtype, + device=a.device) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + ops.vllm_gptq_shuffle(q_weight, q_perm, bit) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +# marlin_24 +def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +if hasattr(ops, "gptq_marlin_24_gemm"): + + @register_fake("_C::gptq_marlin_24_gemm") + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::gptq_marlin_gemm") + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::marlin_qqq_gemm") + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return 
torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake("_C::marlin_gemm") + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake("_C::awq_dequantize") + def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: torch.SymInt, + thx: int, thy: int) -> torch.Tensor: + in_c = qweight.size(0) + qout_c = qweight.size(1) + out_c = qout_c * 8 + return torch.empty((in_c, out_c), + dtype=scales.dtype, + device=scales.device) + + @register_fake("_C::awq_gemm") + def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, + qzeros: torch.Tensor, scales: torch.Tensor, + split_k_iters: torch.SymInt) -> torch.Tensor: + num_in_feats = input.size(0) + return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), + dtype=input.dtype, + device=input.device).sum(0) + + @register_fake("_C::aqlm_gemm") + def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: list[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + out_features = codes.size(0) * codebooks.size(2) + flat_input = input.reshape((-1, input.size(-1))) + flat_output = torch.empty((flat_input.size(0), out_features), + dtype=input.dtype, + device=input.device) + + output_sizes = list(input.shape) + output_sizes.pop() + output_sizes.append(-1) + return flat_output.reshape(tuple(output_sizes)) + + @register_fake("_C::aqlm_dequant") + def _aqlm_dequant_fake( + codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: list[int]) -> torch.Tensor: + in_features = codes.size(1) * 8 + out_features = codes.size(0) + return torch.empty((out_features, in_features), + dtype=codebooks.dtype, + device=codebooks.device) + + @register_fake("_C::fp8_marlin_gemm") + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake("_C::machete_mm") + def machete_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None, + ) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + @register_fake("_C::machete_prepack_B") + def machete_prepack_B_fake( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.empty_like(b_q_weight, + memory_format=torch.contiguous_format) + + +if hasattr(torch.ops._C, "allspark_w8a16_gemm"): + + @register_fake("_C::allspark_w8a16_gemm") + def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: torch.SymInt, group_size: torch.SymInt, + sm_count: 
torch.SymInt, + sm_version: torch.SymInt, + CUBLAS_M_THRESHOLD: torch.SymInt, + has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: + m = a.size(0) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + +if hasattr(torch.ops._C, "ggml_dequantize"): + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake( + W: torch.Tensor, + quant_type: int, + m: torch.SymInt, + n: torch.SymInt, + dtype: Optional[torch.dtype] = None) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + return torch.empty((1, row), dtype=X.dtype, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=X.dtype, device=W.device) + + @register_fake("_C::ggml_moe_a8") + def _ggml_moe_a8_fake( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: torch.SymInt, + top_k: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=torch.float16, + device=W.device) + + +# cutlass +def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) + + +def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, alpha: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + m, n = a.shape[0], b.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, + alpha) + return out + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return False + + +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return False + + +def cutlass_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + format: Optional[str] = "TN") -> torch.Tensor: + """ + `cutlass_scaled_mm` implements a fused version of + `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` + where scale_a * a and scale_b * b are implemented using numpy-style + broadcasting. + + In order to support blockwise scaling like found in DeepSeek V3 we also + support extended "group" broadcast rules. 
We extend the numpy-style + broadcasting rules with the following rule: + "if the extent of a dimension in the source shape is between 1 and + corresponding extent in the target shape we repeat each element along + that dimension src_shape[dim] // target_shape[dim] times consecutively" + example if we have: + a = [[1, 2], and target_shape = (2, 4) + [3, 4]] + then we would expand a to: + a = [[1, 1, 2, 2], + [3, 3, 4, 4]] + currently we only support the case: + scale_a.shape * [1, 128] == a.shape + scale_b.shape * [128, 128] == b.shape + """ + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == b.shape[ + 1] and bias.dtype == out_dtype + + + m = a.shape[0] + n = b.shape[1] + if format == "TN": + b = b.t() + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.w8a8(a, b, scale_a, scale_b, bias, format=format, output=out, out_dtype=out_dtype) + + return out + + +def cutlass_scaled_mm_azp(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. + """ + raise NotImplementedError("FIX soon") + + +def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: + return False + + +def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + +def cutlass_sparse_compress(a: torch.Tensor) \ + -> tuple[torch.Tensor, torch.Tensor]: + """ + Compresses a sparse matrix for use with Cutlass sparse operations. + + This function takes a dense tensor and compresses it into two components: + non-zero elements and metadata. The compressed representation is compatible + with Cutlass sparse kernels. + + Args: + a (torch.Tensor): + The input tensor to be compressed. Must have one of the following data types: + - `torch.int8` + - `torch.float8_e4m3fn` + - `torch.bfloat16` + - `torch.float16` + + Returns: + tuple[torch.Tensor, torch.Tensor]: + A tuple containing: + - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`. + - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation. + + Raises: + ValueError: If the compression operation fails. + + Notes: + - The `a_meta` tensor has a data type of `torch.uint8`. + - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`). + - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor. + - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`. + """ + assert (a.dtype in [ + torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16 + ]) + assert (a.is_contiguous()) + + # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 + elemsPerMetaElem = 4 + assert (a.shape[1] % (2 * elemsPerMetaElem) == 0) + + return torch.ops._C.cutlass_sparse_compress(a) + + +def cutlass_scaled_sparse_mm( + a: torch.Tensor, + bt_nzs: torch.Tensor, + bt_meta: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Performs a scaled sparse matrix multiplication using Cutlass. + + Steps: + 1. 
Create a dense matrix `a` of shape (m, k) on the CUDA device: + `a = torch.randn((m, k), device='cuda')`. + + 2. Create a dense matrix `b` of shape (k, n) on the CUDA device: + `b = torch.randn((k, n), device='cuda')`. + + 3. Prune matrix `b` to 2:4 sparsity along the specified dimension: + `b = prune_to_2_4(b, dim=0)`. + + 4. Compress the transposed sparse matrix `b.t()`: + `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`. + + 5. Perform sparse matrix multiplication using the compressed matrix, + applying scaling factors for `a` and `b`, and the output data type: + `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`. + + Returns: + - The result of the scaled sparse matrix multiplication. + """ + assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == bt_nzs.shape[0] \ + and bias.dtype == out_dtype + + m = a.shape[0] + n = bt_nzs.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a, + scale_b, bias) + + return out + + +def get_cutlass_moe_mm_data( + topk_ids: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, output_permutation: torch.Tensor, + num_experts: int, n: int, k: int): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token-expert mapping) and uses it to + compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation after the input is sorted with + input_permutation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + - input_permutation: Permutation that must be used to shuffle the input + before executing the MMs. + - output_permutation: Permutation that must be used to shuffle the output + after executing the MMs. + """ + torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, output_permutation, + num_experts, n, k) + + +def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, a_strides: torch.Tensor, + b_strides: torch.Tensor, c_strides: torch.Tensor): + """ + A single grouped matrix multiplication used in CUTLASS-based fused MoE. + The function executes fp8-quantized OUT = AB matrix multiplication. + + - expert_offsets: Indices that mark at which token index each expert begins + its computation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + - a/b/c_strides: The data strides passed to grouped matrix multiplication. 
+ """ + torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, + b_scales, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides) + + +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: list[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: List[int]) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +# gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +# gptq_marlin +def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def gptq_marlin_gemm(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +# fp8 marlin +def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, + size_k: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +# machete +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> List[str]: + raise NotImplementedError("FIX soon") + + +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +if hasattr(ops, "permute_cols"): + + @register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +# fp4 +def scaled_fp4_quant( + input: torch.Tensor, + input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to 
FP4 and return quantized tensor and scale. + + This function quantizes the last dimension of the given tensor `input`. For + every 16 consecutive elements, a single dynamically computed scaling factor + is shared. This scaling factor is quantized using the `input_global_scale` + and is stored in a swizzled layout (see + https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x). + + Args: + input: The input tensor to be quantized to FP4 + input_global_scale: A scalar scaling factor for the entire tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every + two values are packed into a uint8 and float8_e4m3 scaling factors + in the sizzled layout. + """ + assert not current_platform.is_rocm() + assert input.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got {input.ndim}.') + other_dims = 1 if input.ndim == 1 else -1 + input = input.reshape(other_dims, input.shape[-1]) + m, n = input.shape + block_size = 16 + device = input.device + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert input.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.') + + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=device, + dtype=torch.int32) + + torch.ops._C.scaled_fp4_quant(output, input, output_scale, + input_global_scale) + output_scale = output_scale.view(torch.float8_e4m3fn) + return output, output_scale + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. 
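+
+    Note: in this port the FP8 quantization kernel is not wired up yet, so
+    calling this function currently raises NotImplementedError.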
+ """ + raise NotImplementedError("FIX soon") + + +# gptq allspark +def allspark_repack_weight( + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None, + has_zp: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format + for Ampere W8A16 Fused Gemm kernel + + Args: + qweight: uint8 weight tensor, original k x n format. + scale: fp16/bf16 weight scale tensor, 1 x n format. + zero_point: fp16/bf16 weight zero_point tensor, 1 x n format. + Must be provided for asymmetric quantization. + has_zp: if use symmetric quantization, has_zp = False. + if use asymmetric quantization, has_zp = True. + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + rearranged weight, scale, and optionally zero_point. + """ + K = qweight.shape[0] + N = qweight.shape[1] + N_32align = (N + 32 - 1) // 32 * 32 + + qweight_reorder = torch.empty((N_32align, K), + device=qweight.device, + dtype=qweight.dtype) + scale_reorder = torch.empty((1, N_32align), + device=scale.device, + dtype=scale.dtype) + zero_point_reorder = None + if has_zp: + assert zero_point is not None, ( + "zero_point must be provided for asymmetric quantization.") + zero_point_reorder = torch.empty((1, N_32align), + device=zero_point.device, + dtype=zero_point.dtype) + + torch.ops._C.rearrange_kn_weight_as_n32k16_order( + qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder, + zero_point_reorder, K, N, N_32align) + + return qweight_reorder, scale_reorder, zero_point_reorder + + +def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], n: int, + group_size: int, sm_count: int, sm_version: int, + CUBLAS_M_THRESHOLD: int, has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: + + return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, + n, group_size, sm_count, + sm_version, CUBLAS_M_THRESHOLD, + has_zp, n32k16_reorder) + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale) + return output, scale, azp + + # dynamic-per-token quantization. 
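+    # Dynamic path: one float32 scale is computed per token row
+    # (input.numel() // input.shape[-1] rows); an int32 azp tensor of the same
+    # shape is allocated only for asymmetric quantization.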
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales) + return output, input_scales, input_azp + + +# qqq ops +def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +# gguf +def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, + dtype: Optional[torch.dtype]) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def ggml_mul_mat_vec_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def ggml_mul_mat_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + raise NotImplementedError("FIX soon") + + +def ggml_moe_a8( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: int, + top_k: int, + tokens: int, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids, + num_tokens_post_padded, quant_type, row, + top_k, tokens) + + +def ggml_moe_get_block_size(quant_type: int) -> int: + return torch.ops._C.ggml_moe_get_block_size(quant_type) + + +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + silu_activation: bool, pad_slot_id: int): + raise NotImplementedError("FIX soon") + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int): + raise NotImplementedError("FIX soon") + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): + raise NotImplementedError("FIX soon") + + +# moe +def moe_sum(input: torch.Tensor, output: torch.Tensor): + raise NotImplementedError("FIX soon") + + +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + ops.vllm_moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts, + block_size, sorted_token_ids, + experts_ids, num_tokens_post_pad) + + +def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, + b_qweight: torch.Tensor, b_scales: torch.Tensor, + 
b_qzeros: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, top_k: int, + BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, + bit: int) -> torch.Tensor: + if not current_platform.is_cuda(): + raise NotImplementedError( + "The optimized moe_wna16_gemm kernel is only " + "available on CUDA platforms") + torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, + b_qzeros, topk_weights, sorted_token_ids, + experts_ids, num_tokens_post_pad, top_k, + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, + bit) + + +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: torch.Tensor) -> None: + ops.vllm_moe_topk_softmax(topk_weights, topk_ids, + token_expert_indicies, gating_output) + + +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, size_k: torch.SymInt, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, 1.0, 1.0) + + +def concat_and_cache_mla( + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + ops.vllm_concat_and_cache_mla(kv_c, k_pe, kv_cache, + slot_mapping, kv_cache_dtype, + scale) + + +def concat_and_cache_mla_int8( + kv_c_int8: torch.Tensor, + kv_c_scale: torch.Tensor, + k_pe_int8: torch.Tensor, + k_pe_scale: torch.Tensor, + kv_cache: torch.Tensor, + kv_cache_scale: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + ops.vllm_concat_and_cache_mla_int8(kv_c_int8,kv_c_scale, k_pe_int8, k_pe_scale, kv_cache, kv_cache_scale, + slot_mapping, kv_cache_dtype, + scale) + + +def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + ops.vllm_copy_blocks(key_caches, value_caches, block_mapping) + + +def copy_blocks_mla(kv_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + ops.vllm_swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float 
= 1.0, + kv_dtype: str = "fp8") -> None: + raise NotImplementedError("FIX soon") + + + +def gather_cache( + src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] + block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] + cu_seq_lens: torch.Tensor, # [BATCH+1] + batch_size: int, + seq_starts: torch.Tensor = None +): + ops.vllm_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts) + +def gather_cache_int8( + src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + src_cache_scale: torch.Tensor,# [NUM_BLOCKS, BLOCK_SIZE, 2] + kv_lora_rank: int, + dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] + block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] + cu_seq_lens: torch.Tensor, # [BATCH+1] + batch_size: int, + seq_starts: torch.Tensor = None +): + ops.vllm_gather_cache_int8(src_cache,src_cache_scale, kv_lora_rank, dst, block_table, cu_seq_lens, batch_size, seq_starts) + + +def get_device_attribute(attribute: int, device: int) -> int: + raise NotImplementedError("FIX soon") + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + return 32 * 1024 + + +# custom ar +def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, + rank: int, fully_connected: bool) -> int: + raise NotImplementedError("Do not use this in our implement") + + +def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, + reg_buffer_sz_bytes: int) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def dispose(fa: int) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def meta_size() -> int: + raise NotImplementedError("Do not use this in our implement") + + +def register_buffer(fa: int, ipc_tensors: List[int]) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + raise NotImplementedError("Do not use this in our implement") + + +def register_graph_buffers(fa: int, handles: List[List[int]], + offsets: List[List[int]]) -> None: + raise NotImplementedError("Do not use this in our implement") + + +def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]: + return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size) + + +def open_mem_handle(mem_handle: torch.Tensor): + return torch.ops._C_custom_ar.open_mem_handle(mem_handle) + + +def free_shared_buffer(ptr: int) -> None: + torch.ops._C_custom_ar.free_shared_buffer(ptr) + + +def get_flash_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._C.get_flash_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). 
+ k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse + + +# Add our new features here.. + +# mee +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_shape: Optional[List[int]] = None, +) -> None: + ops.vllm_invoke_fused_moe_kernel( + A, + B, + C, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + mul_routed_weight, + top_k, + config['BLOCK_SIZE_M'] + ) + + +# broadcast +class Async_helper(): + # For now, the comm and the other kernels are in the same stream, so we can remove the stream wait.. 
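+    # wait() mimics the async work handle returned by torch.distributed ops, so
+    # callers can still call .wait() on the broadcast result; there is nothing
+    # to block on here because the communication shares the compute stream.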
+ def wait(self,): + return True + + +def broadcast(tensor, src=0, group=None, async_op=False): + cdist.broadcast(tensor,src,group,async_op=True) + if async_op: + return Async_helper() + else: + pass + + +# w8a16 +def linear_w8a16(x: torch.Tensor, qweight: torch.Tensor, scales:torch.Tensor, + group_size: int = -1, format: str = "TN")-> torch.Tensor: + return ops.w8a16(x, qweight, scales, format="TN", group_size=group_size) + + +## lora sgmv / bgmv +def sbgmv_expand(x: torch.Tensor, + w_t_all: torch.Tensor, + y: torch.Tensor, + b_seq_start_loc: torch.Tensor = None, + seq_len_tensor: torch.Tensor = None, + lora_indices_tensor: torch.Tensor = None, + batches: int = -1, + max_seq_length: int = -1, + token_nums: int = -1, + add_input=True, + ): + ''' + x: inputs + w_t_all: lora weight + y: output + + y += x@wt_t_all + ''' + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert w_t_all.dtype in [ + torch.float16, + torch.bfloat16, + ] + + assert x.is_contiguous() + # assert y.is_contiguous() + if x.dtype == torch.float: + x = x.to(w_t_all.dtype) + + if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) + assert w_t_all.size(1) == 1 + w_t_all = w_t_all.squeeze(dim=1) + else: + assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) + assert w_t_all.is_contiguous() + + assert add_input == True + + lora_indices = lora_indices_tensor.cpu().tolist() + lora_num = w_t_all.shape[0] + + ## 单一lora model, 且所有request均使用lora + if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): + if lora_indices[0] != -1: + w_t = w_t_all[0] + y += torch.matmul(x, w_t.t()) + ## 多个lora model + else: + ## prefill + if batches != -1: + for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor): + if lora_id != -1: + xi = x[start: start+seq_len] + w_t = w_t_all[lora_id] + y[start:start+seq_len] += (xi @ w_t.t()) + ## decode + else: + batches = x.shape[0] + for i, lora_id in zip(range(batches), lora_indices): + if lora_id != -1: + xi = x[i].unsqueeze(0) + w_t = w_t_all[lora_id] + y[i] += (xi @ w_t.t()).squeeze(0) + + return y + + +def sbgmv_shrink(x: torch.Tensor, + w_t_all: torch.Tensor, + y: torch.Tensor, + b_seq_start_loc: torch.Tensor = None, + seq_len_tensor: torch.Tensor = None, + lora_indices_tensor: torch.Tensor = None, + batches: int = -1, + max_seq_length: int = -1, + token_nums: int = -1, + scale: float = 1.0,): + """ + xx: inputs + w_t_all: lora weight + y: output + scale: float + + y = x@w_t_all * scale + """ + assert x.dtype == w_t_all.dtype + assert x.dtype in [torch.float16, torch.bfloat16] + assert x.is_contiguous() + assert y.is_contiguous() + + if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) + assert w_t_all.size(1) == 1 + w_t_all = w_t_all.squeeze(dim=1) + else: + assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) + assert w_t_all.is_contiguous() + + lora_num = w_t_all.shape[0] + lora_indices = lora_indices_tensor.cpu().tolist() + + ## 单一lora model, 且所有request均使用lora + if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): + if lora_indices[0] != -1: + w_t = w_t_all[0] + y = torch.matmul(x, w_t.t()) * scale + ## 多个lora model + else: + ## prefill + if batches != -1: + for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor): + if lora_id != -1: + xi = x[start: start+seq_len] + w_t = w_t_all[lora_id] + y[start:start+seq_len] = (xi @ w_t.t())* scale + ## decode + else: + batches = x.shape[0] + for i, lora_id in zip(range(batches), lora_indices): + if lora_id != -1: + xi = 
x[i].unsqueeze(0) + w_t = w_t_all[lora_id] + y[i] = (xi @ w_t.t()).squeeze(0) * scale + + return y + +def dynamic_scaled_quant_dynamic_int8(x, input_scales=None, int8_out=None, scales=None): + return ops.dynamic_scaled_quant_smoothquant(x, input_scales, int8_out, scales) + +weak_ref_tensor = ops.weak_ref_tensor diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py new file mode 100644 index 0000000..c3d210c --- /dev/null +++ b/vllm/_ipex_ops.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + import intel_extension_for_pytorch as ipex +except ImportError as e: + logger.warning("Import error msg: %s", e.msg) + + +class ipex_ops: + + @staticmethod + def _reshape_activation_tensor( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + num = x.size(0) + d = x.size(1) // 2 + x = x.reshape(num, 2, d) + x1, x2 = torch.chunk(x, chunks=2, dim=1) + x1 = x1.reshape(num, d) + x2 = x2.reshape(num, d) + return x1, x2 + + @staticmethod + def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.silu_and_mul(x, out) + + @staticmethod + def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_fast(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_new(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_quick(x, out) + + @staticmethod + def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + ipex.llm.modules.PagedAttention.single_query_kv_attention( + out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, + num_queries_per_tokens, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + ipex.llm.modules.PagedAttention.single_query_kv_attention( + out, + query.contiguous(), 
+ key_cache.view_as(value_cache), + value_cache, + num_queries_per_tokens, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def rotary_embedding( + positions: torch.Tensor, # [batch_size, seq_len] + query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] + key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size] + head_size: int, + cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] + is_neox: bool, + ) -> None: + rot_dim = cos_sin_cache.size(1) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim) + + @staticmethod + def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim, + cos_sin_cache_offsets) + + @staticmethod + def rms_norm(input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> torch.Tensor: + return ipex.llm.functional.rms_norm(input, weight, epsilon) + + @staticmethod + def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, + epsilon, True) + input.copy_(tmp) + + @staticmethod + def varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + seqlen_q: torch.Tensor, + seqlen_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + pdropout: float, + softmax_scale: float, + zero_tensors: bool, + is_causal: bool, + return_softmax: bool, + gen_: torch.Generator, + logits_soft_cap: float, + ) -> None: + if ipex.__version__.endswith("cpu"): + if logits_soft_cap != 0.0: + raise ValueError("IPEX CPU does not support logits_soft_cap") + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, + seqlen_q.int(), + seqlen_k.int(), max_seqlen_q, + max_seqlen_k, pdropout, + softmax_scale, zero_tensors, + is_causal, return_softmax, + gen_) + else: # XPU build + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, + seqlen_q.int(), + seqlen_k.int(), max_seqlen_q, + max_seqlen_k, pdropout, + softmax_scale, zero_tensors, + is_causal, return_softmax, + gen_, logits_soft_cap) + + @staticmethod + def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + ) -> None: + assert kv_cache_dtype == "auto" + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.xpu.copy_blocks( # type: ignore + key_caches, + value_caches, + block_mapping, + ) + + @staticmethod + def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py new file mode 100644 index 
0000000..18e0c52 --- /dev/null +++ b/vllm/adapter_commons/layers.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Tuple + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py new file mode 100644 index 0000000..f9a5d2f --- /dev/null +++ b/vllm/adapter_commons/models.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[int, T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: int, value: Optional[T]): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: Dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. 
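+        # Only the keys are meaningful here; every value is kept as None.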
+ self._active_adapters: Dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def capacity(self) -> int: + raise NotImplementedError + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> Dict[int, Any]: + raise NotImplementedError + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py new file mode 100644 index 0000000..2b604b9 --- /dev/null +++ b/vllm/adapter_commons/request.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod + + +class AdapterRequest(ABC): + """ + Base class for adapter requests. + """ + + @property + @abstractmethod + def adapter_id(self) -> int: + raise NotImplementedError + + def __post_init__(self) -> None: + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py new file mode 100644 index 0000000..c2dc543 --- /dev/null +++ b/vllm/adapter_commons/utils.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Optional, Set + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: Dict[int, Any]) -> Optional[Any]: + return 
registered_adapters.get(adapter_id) + + +## worker functions +def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: + return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py new file mode 100644 index 0000000..ce24e08 --- /dev/null +++ b/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Any, Optional, Set + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + raise NotImplementedError + + @abstractmethod + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> Set[int]: + raise NotImplementedError diff --git a/vllm/assets/__init__.py b/vllm/assets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py new file mode 100644 index 0000000..0203dc0 --- /dev/null +++ b/vllm/assets/audio.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal +from urllib.parse import urljoin + +import numpy.typing as npt + +from vllm.utils import PlaceholderModule + +from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +ASSET_DIR = "multimodal_asset" + + +@dataclass(frozen=True) +class AudioAsset: + name: Literal["winning_call", "mary_had_lamb"] + + @property + def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: + audio_path = 
get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + return librosa.load(audio_path, sr=None) + + def get_local_path(self) -> Path: + return get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + + @property + def url(self) -> str: + return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/assets/base.py b/vllm/assets/base.py new file mode 100644 index 0000000..03f3b9d --- /dev/null +++ b/vllm/assets/base.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import lru_cache +from pathlib import Path +from typing import Optional + +import vllm.envs as envs +from vllm.connections import global_http_connection + +VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com" + + +def get_cache_dir() -> Path: + """Get the path to the cache for storing downloaded assets.""" + path = Path(envs.VLLM_ASSETS_CACHE) + path.mkdir(parents=True, exist_ok=True) + + return path + + +@lru_cache +def get_vllm_public_assets(filename: str, + s3_prefix: Optional[str] = None) -> Path: + """ + Download an asset file from ``s3://vllm-public-assets`` + and return the path to the downloaded file. + """ + asset_directory = get_cache_dir() / "vllm_public_assets" + asset_directory.mkdir(parents=True, exist_ok=True) + + asset_path = asset_directory / filename + if not asset_path.exists(): + if s3_prefix is not None: + filename = s3_prefix + "/" + filename + global_http_connection.download_file( + f"{VLLM_S3_BUCKET_URL}/{filename}", + asset_path, + timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT) + + return asset_path diff --git a/vllm/assets/image.py b/vllm/assets/image.py new file mode 100644 index 0000000..2b1d258 --- /dev/null +++ b/vllm/assets/image.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Literal + +import torch +from PIL import Image + +from .base import get_vllm_public_assets + +VLM_IMAGES_DIR = "vision_model_images" + + +@dataclass(frozen=True) +class ImageAsset: + name: Literal["stop_sign", "cherry_blossom"] + + @property + def pil_image(self) -> Image.Image: + image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", + s3_prefix=VLM_IMAGES_DIR) + return Image.open(image_path) + + @property + def image_embeds(self) -> torch.Tensor: + """ + Image embeddings, only used for testing purposes with llava 1.5. 
+ """ + image_path = get_vllm_public_assets(filename=f"{self.name}.pt", + s3_prefix=VLM_IMAGES_DIR) + return torch.load(image_path, map_location="cpu", weights_only=True) diff --git a/vllm/assets/video.py b/vllm/assets/video.py new file mode 100644 index 0000000..32b0b86 --- /dev/null +++ b/vllm/assets/video.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from functools import lru_cache +from typing import Literal + +import cv2 +import numpy as np +import numpy.typing as npt +from huggingface_hub import hf_hub_download +from PIL import Image + +from .base import get_cache_dir + + +@lru_cache +def download_video_asset(filename: str) -> str: + """ + Download and open an image from huggingface + repo: raushan-testing-hf/videos-test + """ + video_directory = get_cache_dir() / "video-example-data" + video_directory.mkdir(parents=True, exist_ok=True) + + video_path = video_directory / filename + video_path_str = str(video_path) + if not video_path.exists(): + video_path_str = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", + filename=filename, + repo_type="dataset", + cache_dir=video_directory, + ) + return video_path_str + + +def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise ValueError(f"Could not open video file {path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frames = [] + + num_frames = num_frames if num_frames > 0 else total_frames + frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + for idx in range(total_frames): + ok = cap.grab() # next img + if not ok: + break + if idx in frame_indices: # only decompress needed + ret, frame = cap.retrieve() + if ret: + frames.append(frame) + + frames = np.stack(frames) + if len(frames) < num_frames: + raise ValueError(f"Could not read enough frames from video file {path}" + f" (expected {num_frames} frames, got {len(frames)})") + return frames + + +def video_to_pil_images_list(path: str, + num_frames: int = -1) -> list[Image.Image]: + frames = video_to_ndarrays(path, num_frames) + return [ + Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + for frame in frames + ] + + +@dataclass(frozen=True) +class VideoAsset: + name: Literal["sample_demo_1.mp4"] + num_frames: int = -1 + + @property + def pil_images(self) -> list[Image.Image]: + video_path = download_video_asset(self.name) + ret = video_to_pil_images_list(video_path, self.num_frames) + return ret + + @property + def np_ndarrays(self) -> npt.NDArray: + video_path = download_video_asset(self.name) + ret = video_to_ndarrays(video_path, self.num_frames) + return ret diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py new file mode 100644 index 0000000..85c5715 --- /dev/null +++ b/vllm/attention/__init__.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend + +__all__ = [ + "Attention", + "AttentionBackend", + "AttentionMetadata", + "AttentionType", + "AttentionMetadataBuilder", + "Attention", + "AttentionState", + "get_attn_backend", +] diff --git a/vllm/attention/backends/__init__.py b/vllm/attention/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/backends/abstract.py 
b/vllm/attention/backends/abstract.py new file mode 100644 index 0000000..82d60f9 --- /dev/null +++ b/vllm/attention/backends/abstract.py @@ -0,0 +1,301 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass, fields +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, + Protocol, Set, Tuple, Type, TypeVar) + +import torch + +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase) + + +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. K/V for encoder-decoder + ENCODER_DECODER = "encoder_decoder" + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. + accept_output_buffer: bool = False + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_state_cls() -> Type["AttentionState"]: + raise NotImplementedError + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + raise NotImplementedError + + def advance_step(self, model_input: "ModelRunnerInputBase", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int) -> None: + raise NotImplementedError + + +@dataclass +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. 
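+    # (Equivalently: slot = physical_block_number * block_size + offset_within_block.)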
+ slot_mapping: torch.Tensor + + # The index maps that relate multi-modal embeddings to the corresponding + # placeholders. + # + # N.B. These aren't really related to attention and don't belong on this + # type -- this is just a temporary solution to make them available to + # `model_executable`. + multi_modal_placeholder_index_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexMap]] + + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. + enable_kv_scales_calculation: bool + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() + # Note that if we add dataclasses as fields, they will need + # similar handling. + return { + field.name: getattr(self, field.name) + for field in fields(self) if field.name not in skip_fields + } + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionState(ABC, Generic[T]): + """Holds attention backend-specific objects reused during the + lifetime of the model runner.""" + + @abstractmethod + def __init__(self, runner: "ModelRunnerBase"): + ... + + @abstractmethod + @contextmanager + def graph_capture(self, max_batch_size: int): + """Context manager used when capturing CUDA graphs.""" + yield + + @abstractmethod + def graph_clone(self, batch_size: int) -> "AttentionState[T]": + """Clone attention state to save in CUDA graph metadata.""" + ... + + @abstractmethod + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + """Get attention metadata for CUDA graph capture of batch_size.""" + ... + + @abstractmethod + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + """Get attention-specific input buffers for CUDA graph capture.""" + ... + + @abstractmethod + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: + """In-place modify input buffers dict for CUDA graph replay.""" + ... + + @abstractmethod + def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + """Prepare state for forward pass.""" + ... 
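+# --- Editor's illustration, not part of the upstream vLLM sources: a minimal,
+# hedged sketch of the flat-slot layout described in the `slot_mapping` comment
+# of `AttentionMetadata` above. The helper name `decompose_slot` is hypothetical
+# and exists only for clarity; block_size=16 matches the example values.
+def decompose_slot(slot: int, block_size: int = 16) -> Tuple[int, int]:
+    """Map a flat slot id to (physical_block_index, offset_within_block)."""
+    return divmod(slot, block_size)
+
+
+# divmod(35, 16) == (2, 3), divmod(2, 16) == (0, 2), divmod(17, 16) == (1, 1),
+# i.e. slot_mapping [35, 2, 17] places the three tokens in blocks 2, 0 and 1.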
+ + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: + """Create the builder, remember some configuration and parameters.""" + raise NotImplementedError + + @abstractmethod + def prepare(self) -> None: + """Prepare for one batch.""" + raise NotImplementedError + + @abstractmethod + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> T: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + +class AttentionLayer(Protocol): + + _q_scale: torch.Tensor + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ... + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + hidden_states_or_cq: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: + return kv_cache_dtype != "auto" diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000..ea4f840 --- /dev/null +++ b/vllm/attention/backends/blocksparse_attn.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. 
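+    # (This sparse block size is independent of the paged KV-cache block size.)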
+ block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. + homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + assert self.num_heads % self.num_kv_heads == 0 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "BLOCK_SPARSE_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return BlocksparseFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: + return BlocksparseFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. 
The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Max number of query tokens for among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionMetadataBuilder( + CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): + + _metadata_cls = BlocksparseFlashAttentionMetadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. 
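+    For example, two prompts of lengths 3 and 5 contribute 8 prompt tokens,
+    while M generation requests contribute one token each (plus any
+    CUDA-graph padding).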
+ + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + assert logits_soft_cap is None, ValueError( + "logits_soft_cap is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + assert output is not None + # Reshape the output tensor. 
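+        # [num_tokens, num_heads, head_size] -> [num_tokens, num_heads * head_size].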
+ return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py new file mode 100644 index 0000000..e2d1690 --- /dev/null +++ b/vllm/attention/backends/cpu_mla.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +import vllm._custom_ops as ops +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState +from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + + +class CPUMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "CPU_MLA" + + @staticmethod + def get_metadata_cls() -> Type["CPUMLAMetadata"]: + return CPUMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]: + return CPUMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_impl_cls() -> Type["CPUMLAImpl"]: + return CPUMLAImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +@dataclass +class CPUMLAMetadata(TorchSDPAMetadata): + # New for MLA + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor = None + + # required by MLACommonImpl + is_profile_run: bool = False + + +class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + assert not self.chunked_prefill, \ + "chunked prefill is currently not supported" + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size): + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # metadata for prefill + if input_data.num_prefills > 0: + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + 
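+            # Illustration: with prefill query lens [4, 6], the cumsum above
+            # fills query_start_loc[1:] so that query_start_loc == [0, 4, 10];
+            # kv_start_loc is built the same way from the prefill seq lens below.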
torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + + # for chunked-prefill + if self.chunked_prefill: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + prefill_block_tables = None + + else: + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + prefill_block_tables = None + + # metadata for decode + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + return CPUMLAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + input_positions=torch.tensor([self.input_data.input_positions])) + + +class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "CPUMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CPUMLAImpl") + + # states is implemented. 
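+        # Quantized (e.g. FP8) KV caches are rejected below; this CPU MLA path
+        # currently supports only the "auto" KV-cache dtype.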
+ if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "CPUMLAImpl with FP8 KV cache not yet supported") + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = torch.empty_like(q) + ipex_ops.varlen_attention( + query=q, + key=k, + value=v_padded, + out=output, + seqlen_q=prefill_metadata.query_start_loc, + seqlen_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.max_query_len, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + logits_soft_cap=0.0, + ) + + # remove padding + output = output.view(-1, self.num_heads, + q.shape[-1])[..., :v.shape[-1]] + output = output.reshape(-1, self.num_heads * v.shape[-1]) + return self.o_proj(output)[0] + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1) + o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank) + + # Run MQA + ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor) + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py new file mode 100644 index 0000000..eef0a62 --- /dev/null +++ b/vllm/attention/backends/flash_attn.py @@ -0,0 +1,995 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashAttention.""" +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm import _custom_ops as ops +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +# yapf: enable +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +# from vllm.vllm_flash_attn import (flash_attn_varlen_func, +# flash_attn_with_kvcache) +from ixformer.contrib.vllm_flash_attn import 
(flash_attn_varlen_func, + flash_attn_with_kvcache) +from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +logger = init_logger(__name__) + + + +class FlashAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 72, 80, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. 
Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. 
+ ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
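+            # Encoder and cross-attention fields are passed through unsliced;
+            # they are only populated for encoder/decoder models and are not
+            # affected by the prefill/decode split.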
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + # if self.use_cuda_graph: + # self.max_decode_seq_len = self.block_tables.shape[-1] * 16 + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
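+            # Fold all prefill requests into the decode counters and reset the
+            # prefill bookkeeping; each converted request now contributes
+            # exactly one query token per step.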
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
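+        4. prefill/decode token counts and per-phase sequence lengths.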
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class FlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. 
+ + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + # self.vllm_flash_attn_version = get_flash_attn_version( + # requires_alibi=self.alibi_slopes is not None) + self.vllm_flash_attn_version = 2 + if is_quantized_kv_cache(self.kv_cache_dtype) and ( + not self.kv_cache_dtype.startswith("fp8") + or not flash_attn_supports_fp8()): + raise NotImplementedError( + f"FlashAttention does not support {self.kv_cache_dtype} " + "kv-cache on this device " + f"(FA supports fp8 = {flash_attn_supports_fp8()}).") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + kv_cache_scale: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + sqrt_alibi: bool = False, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, num_kv_heads, block_size, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. 
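+        # With FA2, or with non-bfloat16 outputs, k/v descale factors are not
+        # supported, so the scales must remain at their default of 1.0.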
+ if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16: + assert ( + layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( + "key/v_scale is only supported in FlashAttention 3 with " + "base dtype bfloat16") + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes + logits_soft_cap: Optional[float] = self.logits_soft_cap + fp8_attention = kv_cache_dtype.startswith("fp8") + + if fp8_attention and not flash_attn_supports_fp8(): + raise NotImplementedError( + "FlashAttention does not support FP8 kv-cache on this device.") + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the + # KV cache is already populated with the cross-attention + # tensor. Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(torch.float8_e4m3fn) + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + + if fp8_attention: + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. 
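+            # Two prefill paths follow: plain varlen attention when no KV
+            # blocks are populated yet, and prefix-cached attention that reads
+            # K/V from the paged cache via the block table.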
+ if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + if fp8_attention: + num_kv_tokens, num_kv_heads, head_size = key.shape + + key, _ = ops.scaled_fp8_quant( + key.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._k_scale) + key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) + + value, _ = ops.scaled_fp8_quant( + value.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._v_scale) + value = value.reshape( + (num_kv_tokens, num_kv_heads, head_size)) + + descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + sqrt_alibi=sqrt_alibi, + out=prefill_output, + # fa_version=self.fa_version, + ) + else: + # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + assert prefill_meta.query_start_loc is not None + max_seq_len = max(prefill_meta.seq_lens) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + sqrt_alibi=sqrt_alibi, + out=prefill_output, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. + + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding + if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1" + ) + assert decode_meta.query_start_loc is not None + descale_shape = (decode_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + cu_seqlens_k=decode_meta.seq_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + sqrt_alibi=sqrt_alibi, + out=decode_output, + # fa_version=self.fa_version, + ) + else: + # Use flash_attn_with_kvcache for normal decoding. 
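+                # Single-token decode: each query is unsqueezed to a length-1
+                # sequence and attends to its cached context through the
+                # block table.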
+ ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) + flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + use_sqrt_alibi=sqrt_alibi, + out=decode_output.unsqueeze(1), + # fa_version=self.fa_version, + use_cuda_graph=decode_meta.use_cuda_graph, + max_context_len=decode_meta.block_tables.shape[-1] * 16 if attn_type == AttentionType.DECODER else decode_meta.max_encoder_seq_len, + ) + return output.view(-1, self.num_heads * self.head_size) + + +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. + # Choose query max length based on whether it is prompt + # or not. + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. 
+ + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py new file mode 100644 index 0000000..0556c19 --- /dev/null +++ b/vllm/attention/backends/flashinfer.py @@ -0,0 +1,1066 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + + from vllm.vllm_flash_attn import flash_attn_varlen_func + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + # Avoid turning these types into variables during type checking + if not TYPE_CHECKING: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.layer import Attention +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER" + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + @staticmethod + def 
get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = vllm_config.compilation_config.static_forward_context + per_layer_params: Dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + assert isinstance(layer, Attention) + + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." 
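+    # Compare every layer's parameters against the first layer's; any mismatch
+    # means the model is not supported by this backend.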
+ + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + self._workspace_buffer = None + self._decode_wrapper = None + self._prefill_wrapper = None + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_decode_wrapper = None + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + self._graph_decode_workspace_buffer = self._get_workspace_buffer() + self._graph_indices_buffer = torch.empty( + max_batch_size * self.runner.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.runner.device) + self._graph_indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + self._graph_last_page_len_buffer = torch.empty( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._graph_decode_workspace_buffer + del self._graph_indices_buffer + del self._graph_indptr_buffer + del self._graph_last_page_len_buffer + del self._graph_decode_wrapper + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + state = self.__class__(self.runner) + state._workspace_buffer = self._graph_decode_workspace_buffer + state._decode_wrapper = self._graph_decode_wrapper + state._prefill_wrapper = self._get_prefill_wrapper() + return state + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] + _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] + + 
num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, "NHD", + use_tensor_cores) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange(0, + batch_size, + dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), + self.runner.block_size, + dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + global_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=self._graph_block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.runner.model_config.get_head_size(), + page_size=self.runner.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.runner.device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=True, + decode_wrapper=self._graph_decode_wrapper, + prefill_wrapper=None, + **dataclasses.asdict(global_params), + ) + attn_metadata.begin_forward() + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + return { + "slot_mapping": attn_metadata.slot_mapping, + } + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + return + + def begin_forward(self, model_input): + assert not self._is_graph_capturing + state = self + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + # In case of multistep chunked-prefill, there might be prefill requests + # scheduled while CUDA graph mode is enabled. We don't run graph in that + # case. + if use_cuda_graph and is_decode: + batch_size = model_input.input_tokens.shape[0] + state = (self.runner.graph_runners[model_input.virtual_engine] + [batch_size].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + ) + model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + model_input.attn_metadata.begin_forward() + + +@dataclass +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Number of query tokens for each request in the batch. 
+ # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = 1 + + use_cuda_graph: bool = True + + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage + seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + # FlashInfer 0.2 encourages passing host tensors + device: torch.device = torch.device("cpu") + is_profile_run: bool = False + + # The FlashInfer backend currently supports only models in which all layers + # share the same following hyperparameters: + + # The left (inclusive) window size for the attention window, when + # set to `-1`, the window size will be set to the full length of + # the sequence. Defaults to `-1`. + window_left: int = -1 + # The attention logits soft capping value (used in Gemini, Grok and + # Gemma-2, etc.), if not provided, will be set to `0`. If greater + # than 0, the logits will be capped according to formula: + # $$\texttt{logits\_soft\_cap} \times + # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, + # where $x$ is the input logits. + logits_soft_cap: Optional[float] = None + # The scale used in softmax, if not provided, will be set to + # `1.0 / sqrt(head_dim)`. 
+ sm_scale: Optional[float] = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + def begin_forward(self): + if self.num_prefill_tokens > 0: + if self.paged_kv_indices is None: + return + + assert self.prefill_wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None + self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] + batch_size = self.query_start_loc.shape[0] - 1 + assert batch_size >= 0 + # We will use flash attention for profiling to + # determine the number of blocks. Therefore, + # we don't need to prepare the input for flashinfer for profile run. + if not self.is_profile_run: + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.prefill_wrapper.plan( + self.query_start_loc, + self.paged_kv_indptr[:self.num_prefills + 1], + self.paged_kv_indices, + self.paged_kv_last_page_len[:self.num_prefills], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.data_type) + if self.num_decode_tokens > 0: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.decode_wrapper is not None + self.decode_wrapper.plan( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + # kv-cache data type. + kv_data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the prefill/decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. 
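+        # Each worker builds its own wrapper objects locally, so dropping them
+        # from the broadcast payload loses nothing.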
+ skip_fields.add('prefill_wrapper') + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_prefills == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. + assert self.decode_query_len == 1 + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens_tensor is not None + + assert num_seqs > 0 + assert num_queries > 0 + assert model_input.attn_metadata is not None + assert sampled_token_ids is not None + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() + + # Update GPU tensors + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=model_input.input_tokens, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_len=self.paged_kv_last_page_len, + block_table_bound=self.block_table_bound) + + +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. 
+ # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + self.total_blocks = 0 + self.is_profile_run: bool = False + + if self.global_hyperparameters is None: + # Infer global hyperparameters, since currently we only support + # models in which all layers share the same values for the + # following hyperparameters: + # - `window_left` + # - `logits_soft_cap` + # - `sm_scale` + inferred_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + self.global_hyperparameters = inferred_params + self.window_left = inferred_params.window_left + self.logits_soft_cap = inferred_params.logits_soft_cap + self.sm_scale = inferred_params.sm_scale + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. 
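+            # Profile runs stop here; real runs fall through and record this
+            # sequence's paged-KV layout below.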
+ if is_profile_run: + self.is_profile_run = is_profile_run + return + + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + decode_query_len = max(query_lens[self.num_prefills:], default=1) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] + for i, block_table in enumerate(self.block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
+ input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None + + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + return FlashInferMetadata( + decode_query_len=decode_query_len, + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=use_captured_graph, + is_profile_run=self.is_profile_run, + window_left=self.window_left, + 
logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + ) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # TODO: directly write to output tensor + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes = self.alibi_slopes + logits_soft_cap = self.logits_soft_cap + + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + # Query for decode. KV is not needed because it is already cached. + # QKV for prefill. 
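+        # Illustrative layout note: tokens are ordered prefill-first, e.g.
+        # with 2 prefill and 3 decode tokens the flattened query is
+        # [p0, p1, d0, d1, d2], so query[:2] is prefill and query[2:] decode.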
+ decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + window_left = window_size[0] if window_size is not None else -1 + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. + if kv_cache.numel() == 0: + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) + else: + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + + assert prefill_meta.prefill_wrapper._causal + assert prefill_meta.prefill_wrapper._window_left == window_left + assert prefill_meta.prefill_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale + + prefill_output = prefill_meta.prefill_wrapper.run( + query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + + assert decode_meta.decode_wrapper._window_left == window_left + assert decode_meta.decode_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert decode_meta.decode_wrapper._sm_scale == softmax_scale + + decode_output = decode_meta.decode_wrapper.run( + decode_query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + + if prefill_output is None and decode_output is not None: + # Decode only batch. + output, num_tokens = decode_output, num_decode_tokens + elif decode_output is None and prefill_output is not None: + # Prefill only batch. + output, num_tokens = prefill_output, num_prefill_tokens + else: + # Chunked prefill batch does not work with speculative decoding in + # FlashInfer backend, so the query length for decode should be 1. 
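+            # Mixed prefill/decode batch: stitch the two halves back together
+            # in the original (prefill-first) token order.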
+ assert prefill_output is not None + assert decode_output is not None + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py new file mode 100644 index 0000000..5d0c230 --- /dev/null +++ b/vllm/attention/backends/flashmla.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashMLAState"]: + return FlashMLAState + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + # TODO: cache assignment? 
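+        # Copy the FlashMLA-specific tile-scheduler tensors onto the decode
+        # view built by the parent MLACommonMetadata property.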
+ if decode_metadata is not None: + decode_metadata.decode_tile_scheduler_metadata=\ + self.decode_tile_scheduler_metadata + decode_metadata.decode_num_splits=\ + self.decode_num_splits + return decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashMLA") + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + + if m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + get_mla_metadata( + m.seq_lens_tensor[m.num_prefills:], + self.num_q_heads, + 1, # MQA for the decode path + ) + + return m + + +class FlashMLAState(MLACommonState[FlashMLAMetadata]): + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_mla_metadata` so we can get the right shapes + self._graph_decoder_tile_scheduler_metadata, \ + self._graph_decode_num_splits = get_mla_metadata( + torch.ones( + max_batch_size, dtype=torch.int32, device=self.runner.device), + self.num_q_heads, + 1, # MQA for the decode path + ) + + with super().graph_capture(max_batch_size): + yield + + del self._graph_decoder_tile_scheduler_metadata + del self._graph_decode_num_splits + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert metadata.num_decode_tokens > 0 + + decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + self._graph_seq_lens[:batch_size], + self.num_q_heads, + 1, # MQA for the decode path + ) + + self._graph_decoder_tile_scheduler_metadata.copy_( + decoder_tile_scheduler_metadata) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) + + metadata.decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata + metadata.decode_num_splits=\ + self._graph_decode_num_splits[:batch_size + 1] + + return metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_tile_scheduler_metadata + input_buffers["decode_num_splits"] = \ + attn_metadata.decode_metadata.decode_num_splits + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + input_buffers["decode_tile_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_tile_scheduler_metadata) + input_buffers["decode_num_splits"].copy_( + 
attn_metadata.decode_metadata.decode_num_splits) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashMLA with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, + num_splits=decode_meta.decode_num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py new file mode 100644 index 0000000..f948fbc --- /dev/null +++ b/vllm/attention/backends/hpu_attn.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company +############################################################################### + +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import vllm_hpu_extension.ops as ops +from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, + VLLMKVCache) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, + HPUPagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class HPUAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "HPU_ATTN" + + @staticmethod + def get_impl_cls() -> Type["HPUAttentionImpl"]: + return HPUAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return HPUAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + HPUPagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): + """Metadata for HPUAttentionbackend.""" + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + attn_bias: Optional[torch.Tensor] + seq_lens_tensor: Optional[torch.Tensor] + + +class HPUAttentionImpl(AttentionImpl, torch.nn.Module): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. 
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        max_seq_len: int = 4096,
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super(AttentionImpl, self).__init__()
+        self.kv_cache_dtype = kv_cache_dtype
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.matmul_qk = Matmul()
+        self.softmax = Softmax()
+        self.matmul_av = Matmul()
+        self.batch2block_matmul = Matmul()
+        self.block2batch_matmul = Matmul()
+        self.k_cache = VLLMKVCache()
+        self.v_cache = VLLMKVCache()
+        ops.pa_impl = ops.pa
+
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        self.alibi_slopes = alibi_slopes
+        if alibi_slopes is not None:
+            alibi_slopes_tensor = torch.tensor(alibi_slopes,
+                                               dtype=torch.bfloat16)
+            self.alibi_slopes = alibi_slopes_tensor
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
+                                              '0').lower() in ['1', 'true']
+        self.fused_scaled_dot_product_attention = None
+        if self.prefill_usefusedsdpa:
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+            try:
+                from habana_frameworks.torch.hpex.kernels import FusedSDPA
+                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
+                    FusedSDPA)
+            except ImportError:
+                logger.warning("Could not import HPU FusedSDPA kernel. "
+                               "vLLM will use native implementation.")
+
+        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "HPUAttentionImpl")
+
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "HPUAttention with FP8 KV cache not yet supported")
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: HPUAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward pass with prompt_attention (prefill) and HPUPagedAttention
+        (decode).
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+ Returns: + shape = [num_tokens, num_heads * head_size] + """ + batch_size, seq_len, hidden_size = query.shape + _, seq_len_kv, _ = key.shape + + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + block_indices = attn_metadata.block_indices + block_offsets = attn_metadata.block_offsets + if attn_metadata.is_prompt: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + if kv_cache is not None: + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) + + if attn_metadata.is_prompt: + # Prompt run. + if not self.prefill_usefusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ + 'attn_bias must be set before calling model.forward!' + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + else: + attn_bias = None + + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + out = ops.prompt_attention( + query.view(query_shape), + key.view(kv_shape), + value.view(kv_shape), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + fsdpa_op=self.fused_scaled_dot_product_attention, + ) + output = out.reshape(batch_size, seq_len, hidden_size) + else: + # Decoding run. + output = HPUPagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + block_scales=attn_metadata.block_scales, + block_groups=attn_metadata.block_groups, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + batch2block_matmul_op=self.batch2block_matmul, + block2batch_matmul_op=self.block2batch_matmul, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_len: int, +) -> torch.Tensor: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. 
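+    # e.g. for seq_len = 3 this gives [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]],
+    # i.e. entry (i, j) = j - i.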
+ bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + return bias diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py new file mode 100644 index 0000000..d3c61ea --- /dev/null +++ b/vllm/attention/backends/ipex_attn.py @@ -0,0 +1,388 @@ +# SPDX-License-Identifier: Apache-2.0 +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + +_PARTITION_SIZE = 512 + + +class IpexAttnBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "IPEX" + + @staticmethod + def get_impl_cls() -> Type["IpexAttnBackendImpl"]: + return IpexAttnBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for IpexAttnBackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + seqlen_q: Optional[torch.Tensor] + max_seqlen: Optional[int] + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. 
+ # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + @property + def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + if logits_soft_cap is None: + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError( + "IPEX backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") + + def split_kv_cache( + self, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 1 + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: IpexAttnMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with IPEX varlen_attention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. 
+ Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = self.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, None, dtype=query.dtype) + attn_metadata.attn_bias = att_masks + + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + ipex_ops.varlen_attention( + query, + key, + value, + output, + attn_metadata.seqlen_q, + attn_metadata.seqlen_q, + attn_metadata.max_seqlen, + attn_metadata.max_seqlen, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + logits_soft_cap=self.logits_soft_cap, + ) + else: + # prefix-enabled attention + raise RuntimeError( + "IPEX backend doesn't support prefix decoding.") + + else: + # Decoding run. + max_seq_len = attn_metadata.max_decode_seq_len + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. + use_v1 = (max_seq_len <= 8192 and + (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. + ipex_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + # Run PagedAttention V2. 
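+                # V2 splits each sequence into _PARTITION_SIZE-token
+                # partitions, writes per-partition results into tmp_output and
+                # reduces them with exp_sums / max_logits.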
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ipex_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype, + device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/mla/__init__.py b/vllm/attention/backends/mla/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py new file mode 100644 index 0000000..a96494b --- /dev/null +++ b/vllm/attention/backends/mla/common.py @@ -0,0 +1,1698 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. 
the memory friendly approach) the attention "simulates" a
+multi-head attention, while the compute is similar to multi-query attention.
+
+Below is an example of both paths, assuming batch size = 1
+
+## More Extent Definitions:
+
+C           Context length, `Skv - Sq`
+H           hidden size
+N           number of attention heads
+Lq          latent dimension for Q              1536 in DSV3
+Lkv         latent dimension for K/V            512 in DSV3
+P           nope dimension, no rope.            128 in DSV3
+R           rope dimension, goes through rope.  64 in DSV3
+V           V head dim.                         128 in DSV3
+
+## Vector/Matrix Definitions
+
+h_t         hidden states (input to attention)  shape [Sq, H]
+q_c         latent/compressed Q                 shape [Sq, Lq]
+q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
+q_pe        uncompressed Q (rope)               shape [Sq, N, R]
+kv_c        latent/compressed KV                shape [Skv, Lkv]
+k_pe        decoupled k position embeddings     shape [Skv, R]
+new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
+new_k_pe    new k_pe from current iter          shape [Sq, R]
+cache_kv_c  cached k_c from previous iters      shape [C, Lkv]
+cache_k_pe  cached k_pe from previous iters     shape [C, R]
+W_DQ        project h_t to q_c                  shape [H, Lq]
+W_UQ        project q_c to q_nope               shape [Lq, N * P]
+W_QR        project q_c to q_pe                 shape [Lq, N * R]
+W_DKV       project h_t to kv_c                 shape [H, Lkv]
+W_UK        project kv_c to k_nope              shape [Lkv, N, P]
+W_KR        project h_t to k_pe                 shape [H, R]
+W_UV        project kv_c to v                   shape [Lkv, N, V]
+W_O         project v to h_t                    shape [N * V, H]
+
+
+## Compute Friendly Approach (i.e. "_forward_prefill"):
+
+q_c = h_t @ W_DQ
+q_nope = (q_c @ W_UQ).view(Sq, N, P)
+q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
+k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
+k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
+v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
+
+// MHA with QK headdim = P + R
+// V headdim = V
+// spda_o shape [Sq, N, V]
+spda_o = scaled_dot_product_attention(
+    torch.cat([q_nope, q_pe], dim=-1),
+    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+    v
+)
+return spda_o @ W_O
+
+NOTE: in the actual code,
+    `kv_b_proj` is [W_UK; W_UV] concatenated per head
+    `q_b_proj` is [W_UQ; W_QR] concatenated per head
+    `out_proj` is W_O
+
+
+## Data-Movement Friendly Approach (i.e. "_forward_decode"):
+
+Runtime
+q_c = h_t @ W_DQ
+q_nope = (q_c @ W_UQ).view(-1, N, P)
+ql_nope = einsum("snh,lnh->snl", q_nope, W_UK)
+q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
+k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
+
+// MQA with QK headdim = Lkv + R
+// V headdim = Lkv
+// spda_o shape [Sq, N, Lkv]
+// NOTE: this is less compute-friendly since Lkv > P
+// but is more data-movement friendly since it's MQA vs MHA
+spda_o = scaled_dot_product_attention(
+    torch.cat([ql_nope, q_pe], dim=-1),
+    torch.cat([kv_c, k_pe], dim=-1),
+    kv_c
+)
+
+o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
+return o.view(-1, N * V) @ W_O
+
+
+## Chunked Prefill
+
+For chunked prefill we want to use the compute friendly algorithm. We are
+assuming a sufficiently large Sq / Skv ratio; in the future we may want to
+switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is
+small.
+
+However, the compute-friendly approach can potentially run out of memory if Skv
+is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
+
+To mitigate this, we chunk the computation of attention with respect to the
+current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
+fixed workspace size.
+
+The chunked prefill approach is as follows:
+
+MCC        Max chunk of context to process per iter, computed dynamically,
+           used to bound the memory usage
+
+q_c = h_t @ W_DQ
+q_nope = (q_c @ W_UQ).view(Sq, N, P)
+q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
+new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
+
+// MHA between queries and new KV
+// with QK headdim = P + R
+// V headdim = V
+// curr_o shape [Sq, N, V]
+// curr_lse shape [N, Sq], this is just the order FA returns
+curr_o, curr_lse = scaled_dot_product_attention(
+    torch.cat([q_nope, q_pe], dim=-1),
+    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+    new_v,
+    causal=True,
+    return_softmax_lse=True
+)
+
+// Compute attention with the already existing context
+for chunk_idx in range(cdiv(C, MCC)):
+    chunk_start = chunk_idx * MCC
+    chunk_end = min(chunk_start + MCC, C)
+    Sc = chunk_end - chunk_start
+    cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
+    cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
+    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
+    cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
+
+    chunk_o, chunk_lse = scaled_dot_product_attention(
+        torch.cat([q_nope, q_pe], dim=-1),
+        torch.cat([cache_k_nope_chunk,
+                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
+                  dim=-1),
+        cache_v_chunk,
+        causal=False,
+        return_softmax_lse=True
+    )
+
+    curr_o, curr_lse = merge_attn_states(
+        suffix_output=curr_o,
+        suffix_lse=curr_lse,
+        prefix_output=chunk_o,
+        prefix_lse=chunk_lse,
+    )
+
+return curr_o @ W_O
+"""
+
+import functools
+from abc import abstractmethod
+from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
+from itertools import accumulate
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
+                    Type, TypeVar)
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm import envs
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionState, MLAAttentionImpl)
+from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                           compute_slot_mapping_start_idx,
+                                           is_block_tables_empty)
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase, RowParallelLinear,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
+from vllm.triton_utils import HAS_TRITON
+from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
+
+if HAS_TRITON:
+    from vllm.attention.ops.triton_flash_attention import triton_attention
+    from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
+else:
+    merge_attn_states = None
+    triton_attention = None
+
+# try:
+#     from vllm.vllm_flash_attn import flash_attn_varlen_func
+#     is_vllm_fa = True
+# except ImportError:
+#     is_vllm_fa = False
+#     try:
+#         # For rocm use upstream flash attention
+#         from flash_attn import flash_attn_varlen_func
+#     except ImportError:
+#         flash_attn_varlen_func = None
+is_vllm_fa = True
+from ixformer.contrib.vllm_flash_attn import 
flash_attn_varlen_func,merge_attn_states + + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_quantize, scaled_dequantize) +import ixformer.inference.functions as ixf_ops + +is_hip = current_platform.is_rocm() + + +class MLACommonBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +T = TypeVar("T", bound="MLACommonMetadata") + + +class MLACommonState(AttentionState, Generic[T]): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + scheduler_config = runner.scheduler_config + self.model_config = runner.model_config + cache_config = runner.cache_config + + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.enable_prefix_caching = cache_config.enable_prefix_caching + + if self.chunked_prefill_enabled or self.enable_prefix_caching: + self.context_chunk_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * self.model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.context_chunk_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def 
graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + use_cuda_graph=True, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + input_positions=self._positions[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + "input_positions": attn_metadata.decode_metadata.input_positions, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_positions = attn_metadata.input_positions + num_positions = input_positions.shape[0] + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # CUDA graph buffer is padded so only perform a partial copy based on + # num_positions + input_buffers["input_positions"][:num_positions].copy_( + input_positions, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + if self.chunked_prefill_enabled or self.enable_prefix_caching: + if not hasattr(self, "context_chunk_workspace"): + # not self.runner.device does not return the correct device + # for this process, (init_device sets the correct device but + # only on the Worker). The only way Ive figured out to get the + # correct device is to allocate the workspace on the first call + # to begin_forward and use the device of the input tokens + assert model_input.input_tokens is not None + self.context_chunk_workspace = torch.empty( + (self.context_chunk_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=model_input.input_tokens.device, + ) + + model_input.attn_metadata.context_chunk_workspace = \ + self.context_chunk_workspace + + +@dataclass +class MLACommonMetadata(AttentionMetadata): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Whether or not if cuda graph is enabled. 
+ # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # New for MLA (compared to FlashAttention) + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
+ seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional[Any] = None + _cached_decode_metadata: Optional[Any] = None + + num_prefill_tokens: int + + # The dimension of the attention heads + head_dim: Optional[int] = None + + # Used when chunked prefill is enabled to simulate worst case workspace + # allocations, hopefully to avoid going OOM + is_profile_run: bool = False + + # New for MLA (compared to FlashAttention) + # For chunked prefill + context_chunk_cu_seq_lens: Optional[torch.Tensor] = None + context_chunk_starts: Optional[torch.Tensor] = None + context_chunk_seq_tot: Optional[List[int]] = None + context_chunk_max_seq_lens: Optional[List[int]] = None + # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted + context_chunk_workspace: Optional[torch.Tensor] = None + + def __post_init__(self): + supported_head_sizes = MLACommonBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + @property + def prefill_metadata(self): + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + input_positions = (None if self.input_positions is None else + self.input_positions[:self.num_prefill_tokens]) + + self._cached_prefill_metadata = self.__class__( + # Required by ModelRunner + use_cuda_graph=False, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, + context_chunk_starts=self.context_chunk_starts, + context_chunk_seq_tot=self.context_chunk_seq_tot, + context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + if self.num_decode_tokens == 0: + return 
None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + input_positions = (None if self.input_positions is None else + self.input_positions[self.num_prefill_tokens:]) + if self.use_cuda_graph: + self.max_decode_seq_len = self.block_tables.shape[-1] * 16 + + self._cached_decode_metadata = self.__class__( + # Required by ModelRunner + use_cuda_graph=self.use_cuda_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + input_positions=input_positions, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
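+            # Illustrative walk-through (assumed numbers, not from the original
+            # source): with num_prefills=2, num_prefill_tokens=10 and
+            # num_decode_tokens=4 over num_seqs=6, the update below yields
+            # num_decode_tokens=6 and resets every prefill counter to zero, so
+            # the whole batch is treated as decode from this step on.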
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.chunked_prefill_enabled = \ + self.runner.scheduler_config.chunked_prefill_enabled + self.enable_prefix_caching = \ + self.runner.cache_config.enable_prefix_caching + + if self.chunked_prefill_enabled or self.enable_prefix_caching: + attn_state = self.input_builder.runner.attn_state + self.context_chunk_workspace_size = \ + attn_state.context_chunk_workspace_size + self.page_size = self.runner.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.input_positions: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + input_positions = async_tensor_h2d(self.input_positions, torch.long, + device, self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + context_chunk_cu_seq_lens = None + context_chunk_starts = None + context_chunk_seq_tot = None + context_chunk_max_seq_lens = None + + if (self.chunked_prefill_enabled or self.enable_prefix_caching) \ + and self.num_prefills > 0 \ + and context_lens_tensor is not None \ + and context_lens_tensor[:self.num_prefills].max() > 0: + + # NOTE: it is recommend you read the `Chunked Prefill` section in + # the comment at the top of the file before trying to understand + # the following code + + num_prefills_with_context = \ + (context_lens_tensor[:self.num_prefills] > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.context_chunk_workspace_size // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + context_chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32)\ + .unsqueeze(1).expand(-1, 
self.num_prefills)\ + * max_context_chunk + chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ + .unsqueeze(0), context_chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) + _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ + .unsqueeze(-1) + context_chunk_cu_seq_lens = \ + torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) + context_chunk_max_seq_lens = \ + chunk_seq_lens.max(dim=1).values.tolist() + context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() + assert max(context_chunk_seq_tot) <= \ + self.context_chunk_workspace_size + + return self.runner.attn_backend.make_metadata( + # Required by ModelRunner + use_cuda_graph=use_captured_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, # Not Attention Related + enable_kv_scales_calculation=False, + # MLACommonMetadata + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.runner.model_config.get_head_size(), + is_profile_run=self.runner.in_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, + context_chunk_starts=context_chunk_starts, + context_chunk_seq_tot=context_chunk_seq_tot, + context_chunk_max_seq_lens=context_chunk_max_seq_lens, + ) + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + self.rotary_emb = rotary_emb + self.use_yarn_rope = isinstance(rotary_emb, + DeepseekScalingRotaryEmbedding) + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + self.triton_fa_func = triton_attention + + # Handle the differences between the flash_attn_varlen 
from flash_attn + # and the one from vllm_flash_attn. The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + def _v_up_proj_and_o_proj(self, x): + # # Convert from (B, N, L) to (N, B, L) + # x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + # x = torch.bmm(x, self.W_UV) + # # Convert from (N, B, V) to (B, N * V) + # x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + # return self.o_proj(x)[0] + x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) + return self.o_proj(x.reshape(-1, self.num_heads * self.v_head_dim))[0] + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = self.q_proj(x)[0].view(-1, self.num_heads, self.qk_head_dim).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # # Convert from (B, N, P) to (N, B, P) + # q_nope = q_nope.transpose(0, 1) + # # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + # ql_nope = torch.bmm(q_nope, self.W_UK_T) + # # Convert from (N, B, L) to (B, N, L) + # return ql_nope.transpose(0, 1), q_pe + return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank), q_pe + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # we already prepare a dequant weight for GGUF, skip it. + if layer.quant_method.__class__.__name__ == "GGUFLinearMethod": + weight = layer.weight.clone() + del layer.weight + return weight + if layer.quant_method.__class__.__name__ == "AWQMarlinLinearMethod": + from ixformer.inference.functions import ref_wui4a16 + return ref_wui4a16(None, layer.qweight, layer.scales, layer.qzeros, None, layer.quant_method.quant_config.group_size, only_return_weight=True) + # for W8A8, we directly dequantize it here to avoiding quantization errors + if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" and not layer.scheme.is_static_input_scheme: + quant_weight = layer.weight.T # output, input + scales = layer.weight_scale + return scaled_dequantize(quant_weight, scales, (1, -1), act_dtype) + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + back_to_vllm = False + weight_dtype = get_layer_weight(self.kv_b_proj).dtype + assert get_layer_weight(self.q_proj).dtype == weight_dtype + + # when use customize forward, we do not care the specific data types and shape, just reset the + # _v_up_proj_and_o_proj and _q_proj_and_k_up_proj funs to get the correct results. 
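+        # Shape sketch (for readability, mirroring the unquantized fallback below):
+        #   kv_b_proj weight viewed as (kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)
+        #   W_UK -> (kv_lora_rank, num_heads, qk_nope_head_dim), absorbs the k up-projection into q
+        #   W_UV -> (kv_lora_rank, num_heads, v_head_dim), absorbs the v up-projection before o_proj
+        # The quantized branches below only repack these two slices for the
+        # ixformer marlin kernels instead of keeping fp16/bf16 copies.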
+ if envs.VLLM_MLA_CUSTOMIZE: + layer = self.kv_b_proj + quant_method = layer.quant_method.__class__.__name__ + if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8": + # must be as same as kv_b_proj + assert hasattr(self.q_proj, "scheme") and self.q_proj.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" + assert hasattr(self.o_proj, "scheme") and self.o_proj.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" + + kv_b_proj_weight = self.kv_b_proj.weight # [i,o] + kv_b_proj_weight_scale = self.kv_b_proj.weight_scale # [o, 1] + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_b_proj_weight_scale = kv_b_proj_weight_scale.view( + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + 1, + ) + W_UK_S, W_UV_S = kv_b_proj_weight_scale.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=1) + + # self.W_UK = W_UK # [kv_lora_rank, n, qk_nope] + self.W_UK = ixf_ops.marlin_w8_weight_repack( + weights=W_UK.permute(1, 2, 0).contiguous(), + weight_format="int8", + reformat="k16n16_grouped_n", + ) + self.W_UK_S = W_UK_S.contiguous() # [n, qk_nope, 1] + # self.W_UV = W_UV # [kv_lora_rank, n, v_head_dim] + self.W_UV = ixf_ops.marlin_w8_weight_repack( + weights=W_UV.permute(1, 0, 2).contiguous(), + weight_format="int8", + reformat="k16n16", + ) + self.W_UV_S = W_UV_S.contiguous() # [n, v_head_dim, 1] + + def _v_up_proj_and_o_proj_w8a8(self, x): + # x: b, n, kv_lora + # W_UV = (self.W_UV * self.W_UV_S.view(1, self.num_heads, self.v_head_dim)).to(x.dtype) + # x = torch.einsum("bnl,lnv->bnv", x, W_UV) + res = ixf_ops.marlin_w8a16( + x, + self.W_UV, + self.W_UV_S, + group_size=-1, + format="k16n16", + batch_first=False, + ) + return self.o_proj(res.reshape(-1, self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj_w8a8(self, x): + # x: b, q_lora + x = self.q_proj(x)[0].view(-1, self.num_heads, self.qk_head_dim) + x_nope, x_pe = x.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # W_UK = (self.W_UK * self.W_UK_S.view(1, self.num_heads, self.qk_nope_head_dim)).to(x.dtype) + # return torch.einsum("bnp,lnp->bnl", x_nope, W_UK).view(-1, self.num_heads, self.kv_lora_rank), x_pe + res = ixf_ops.marlin_w8a16( + x_nope, + self.W_UK, + self.W_UK_S, + group_size=-1, + format="k16n16_grouped_n", + batch_first=False, + ) + return res, x_pe + + self._v_up_proj_and_o_proj = functools.partial(_v_up_proj_and_o_proj_w8a8, self) + self._q_proj_and_k_up_proj = functools.partial(_q_proj_and_k_up_proj_w8a8, self) + elif quant_method == "AWQMarlinLinearMethod": + # ===== W_UK & W_UV use W4A16 ===== + def split_4bit_matrix(qweight, scales, qzeros, num_heads, head_dim_1, head_dim_2): + assert head_dim_1 % 8 == 0 + assert head_dim_2 % 8 == 0 + qweight = qweight.view(-1, num_heads, (head_dim_1 + head_dim_2)//8) + w1, w2 = qweight.split([head_dim_1//8, head_dim_2//8], dim=-1) + scales = scales.view(-1, num_heads, head_dim_1 + head_dim_2) + s1, s2 = scales.split([head_dim_1, head_dim_2], dim=-1) + qzeros = qzeros.view(-1, num_heads, (head_dim_1 + head_dim_2)//8) + z1, z2 = qzeros.split([head_dim_1//8, head_dim_2//8], dim=-1) + return (w1.contiguous().view(w1.shape[0], -1), + s1.contiguous().view(s1.shape[0], -1), + z1.contiguous().view(z1.shape[0], -1), + w2.contiguous().view(w2.shape[0], -1), + s2.contiguous().view(s2.shape[0], -1), + z2.contiguous().view(z2.shape[0], -1)) + # 
split W_UK and W_UV + ( + W_UK, W_UK_S, W_UK_Z, + W_UV, W_UV_S, W_UV_Z + ) = split_4bit_matrix( + self.kv_b_proj.qweight, self.kv_b_proj.scales, self.kv_b_proj.qzeros, + self.num_heads, self.qk_nope_head_dim, self.v_head_dim + ) + # repack W_UK + W_UK = W_UK.reshape(self.kv_lora_rank, self.num_heads, -1).permute(1, 2, 0).contiguous() + W_UK_S = W_UK_S.reshape(W_UK_S.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + W_UK_Z = W_UK_Z.reshape(W_UK_Z.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + ( + self.W_UK, + self.W_UK_S, + self.W_UK_Z + ) = ixf_ops.marlin_w4_weight_repack( + weights = W_UK, + scales = W_UK_S, + zeros = W_UK_Z, + weight_format = "gptq_grouped_n", + reformat = "k16n32_grouped_n", + ) + # repack W_UV + W_UV = W_UV.reshape(self.kv_lora_rank, self.num_heads, -1).permute(1, 0, 2).contiguous() + W_UV_S = W_UV_S.reshape(W_UV_S.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + W_UV_Z = W_UV_Z.reshape(W_UV_Z.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + ( + self.W_UV, + self.W_UV_S, + self.W_UV_Z + ) = ixf_ops.marlin_w4_weight_repack( + weights = W_UV, + scales = W_UV_S, + zeros = W_UV_Z, + weight_format = "awq", + reformat = "k16n32", + ) + self.w4_group_size = self.q_proj.quant_method.quant_config.group_size + + # ===== W_UK & W_UV use W16A16 ===== + # kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + # kv_b_proj_weight = kv_b_proj_weight.view( + # self.kv_lora_rank, + # self.num_heads, + # self.qk_nope_head_dim + self.v_head_dim, + # ) + # self.W_UK_fp16, self.W_UV_fp16 = kv_b_proj_weight.split( + # [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + def _v_up_proj_and_o_proj_w4a16(self, x): + # W16A16 + # res = torch.einsum("bnl,lnv->bnv", x, self.W_UV_fp16) + + # W4A16 + res = ixf_ops.marlin_w4a16( + x, self.W_UV, self.W_UV_S, self.W_UV_Z, + group_size=self.w4_group_size, + format="k16n32", + batch_first=False) + + return self.o_proj(res.view(-1, self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj_w4a16(self, x): + x = self.q_proj(x)[0].view(-1, self.num_heads, self.qk_head_dim) + x_nope, x_pe = x.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # W16A16 + # return torch.einsum("bnp,lnp->bnl", x_nope, self.W_UK_fp16).view(-1, self.num_heads, self.kv_lora_rank), x_pe + + # W4A16 + res = ixf_ops.marlin_w4a16( + x_nope, self.W_UK, self.W_UK_S, self.W_UK_Z, + group_size=self.w4_group_size, + format="k16n32_grouped_n", + batch_first=False, + ) + return res, x_pe + + self._v_up_proj_and_o_proj = functools.partial(_v_up_proj_and_o_proj_w4a16, self) + self._q_proj_and_k_up_proj = functools.partial(_q_proj_and_k_up_proj_w4a16, self) + else: + print(f"Custom MLA for quant method: {layer.quant_method.__class__.__name__} is not supported, will use the vllm official impl.") + back_to_vllm = True + else: + back_to_vllm = True + + if back_to_vllm: + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + 
self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # # Convert from (L, N, V) to (N, L, V) + # self.W_UV = W_UV.transpose(0, 1) + # # Convert from (L, N, P) to (N, P, L) + # self.W_UK_T = W_UK.permute(1, 2, 0) + + self.W_UV = W_UV + self.W_UK = W_UK + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache_scale: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + # TODO + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + assert prefill_metadata.context_chunk_seq_tot is not None + assert prefill_metadata.context_chunk_cu_seq_lens is not None + assert prefill_metadata.context_chunk_starts is not None + assert prefill_metadata.context_chunk_max_seq_lens is not None + assert prefill_metadata.context_lens_tensor is not None + + output = None + iters = len(prefill_metadata.context_chunk_seq_tot) + + # Fetch from attn_metadata directly, since it late bound by + # MLAAttentionState, grabbing it directly `attn_metadata` can avoid + # any weirdness around prefill_metadata caching + assert attn_metadata.context_chunk_workspace is not None + workspace = attn_metadata.context_chunk_workspace + + for i in range(iters): + toks = prefill_metadata.context_chunk_seq_tot[i] + if envs.VLLM_USE_INT8_MLA: + ops.gather_cache_int8( + src_cache=kv_c_and_k_pe_cache, + src_cache_scale=kv_c_and_k_pe_cache_scale, + kv_lora_rank = self.kv_lora_rank, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], + batch_size=prefill_metadata.num_prefills, + seq_starts=prefill_metadata.context_chunk_starts[i], + ) + else: + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], + batch_size=prefill_metadata.num_prefills, + seq_starts=prefill_metadata.context_chunk_starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank].contiguous() + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) + + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + # v_padded = torch.nn.functional.pad(v, + # [0, q.shape[-1] - v.shape[-1]], + # value=0) + + if is_vllm_fa: + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata. + context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + out=attn_output, + ) + else: + attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata. 
+ context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_attn_probs=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache_scale: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v_nope = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + v = v_nope + k[...,:self.qk_nope_head_dim] = k_nope + attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + # v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + # value=0) + # attn_output = torch.empty(q.shape[0], self.num_heads, q.shape[-1], dtype=q.dtype, device=q.device) + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: + output = self.triton_fa_func( + q, + k, + v, + None, + prefill_metadata.query_start_loc, + prefill_metadata.query_start_loc, + prefill_metadata.max_prefill_seq_len, + prefill_metadata.max_prefill_seq_len, + True, # causal + self.scale, + None, # attn_mask is None unless applying ALiBi mask + ) + ## triton flash attention always return 2 objects + if not has_context: + output = output[0] + elif is_vllm_fa: + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + out=attn_output, + ) + else: + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_attn_probs=has_context, + out=attn_output, + ) + + if has_context: + # TODO + # ROCm flash_attn_varlen_func will return 3 objects instead of 2 + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, kv_c_and_k_pe_cache_scale, attn_metadata) + + output = torch.empty_like(suffix_output) + output = merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # slice by `:v.shape[-1]` in order to remove v headdim padding + # output = output.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]].reshape(-1, self.num_heads * self.v_head_dim) 
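+        # In this port v stays at v_head_dim (the padding above is commented out)
+        # and attn_output is allocated with v_head_dim per head, so the result can
+        # be flattened for o_proj without slicing off any padding.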
+ output = output.view(-1, self.num_heads * self.v_head_dim) + + return self.o_proj(output)[0] + + @abstractmethod + def _forward_decode( + self, + ql_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache_scale: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + kv_cache_scale:torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + if attn_metadata.is_profile_run and \ + attn_metadata.context_chunk_workspace is not None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + (attn_metadata.context_chunk_workspace.shape[0], + self.num_heads, self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + has_decode = attn_metadata.decode_metadata is not None + has_prefill = attn_metadata.prefill_metadata is not None + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + assert hasattr(attn_metadata, "input_positions") + + num_prefill_tokens: int = attn_metadata.num_prefill_tokens + + decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:] + decode_k_pe = k_pe[num_prefill_tokens:] + decode_input_positions = \ + attn_metadata.input_positions[num_prefill_tokens:] + + prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens] + prefill_k_pe = k_pe[:num_prefill_tokens] + prefill_input_positions = \ + attn_metadata.input_positions[:num_prefill_tokens] + prefill_k_c_normed = k_c_normed[:num_prefill_tokens] + + if has_decode: + decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_pe, decode_k_pe = self.rotary_emb(decode_input_positions, decode_q_pe, decode_k_pe) + + if has_prefill: + prefill_q = self.q_proj(prefill_hs_or_q_c)[0].view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_k = torch.empty_like(prefill_q) + ixf_ops.mla_rope(prefill_input_positions, prefill_q_pe, prefill_k_pe.squeeze(1), prefill_k[...,self.qk_nope_head_dim:], self.rotary_emb.cos_sin_cache) + + + # write the latent and rope to kv cache + write_kv_cache = (None, None) + if kv_cache.numel() > 0: + # TODO.. 
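+            # Cache layout note: each slot holds the compressed latent
+            # (kv_lora_rank) followed by the rotary part of k (qk_rope_head_dim).
+            # With VLLM_USE_INT8_MLA both pieces are int8-quantized and their
+            # scales go to kv_cache_scale; in the bf16/fp16 case the decode-side
+            # write is deferred and handed to _forward_decode via write_kv_cache.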
+ if has_decode: + if envs.VLLM_USE_INT8_MLA: + write_kv_cache = (None, None) + k_c_normed_int8, k_c_normed_scale,_ = ops.scaled_int8_quant(k_c_normed[num_prefill_tokens:]) + decode_k_pe_int8, decode_k_pe_scale,_ = ops.scaled_int8_quant(decode_k_pe) + ops.concat_and_cache_mla_int8( + kv_c_int8 = k_c_normed_int8, + kv_c_scale = k_c_normed_scale[...,0], + k_pe_int8 = decode_k_pe_int8, + k_pe_scale = decode_k_pe_scale[...,0].view(-1,decode_k_pe_int8.shape[-2]), + kv_cache = kv_cache, + kv_cache_scale = kv_cache_scale, + slot_mapping = attn_metadata.slot_mapping.flatten()[num_prefill_tokens:], + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + else: + # ops.concat_and_cache_mla( + # k_c_normed[num_prefill_tokens:], + # decode_k_pe, + # kv_cache, + # attn_metadata.slot_mapping.flatten()[num_prefill_tokens:], + # kv_cache_dtype=self.kv_cache_dtype, + # scale=layer._k_scale, + # ) + write_kv_cache = (k_c_normed[num_prefill_tokens:], decode_k_pe) + if has_prefill: + if envs.VLLM_USE_INT8_MLA: + prefill_k_c_normed_int8, prefill_k_c_normed_scale,_ = ops.scaled_int8_quant(prefill_k_c_normed) + prefill_k_pe_int8, prefill_k_pe_scale,_ = ops.scaled_int8_quant(prefill_k[...,self.qk_nope_head_dim:].contiguous()) + ops.concat_and_cache_mla_int8( + prefill_k_c_normed_int8, + prefill_k_c_normed_scale[...,0], + prefill_k_pe_int8, + prefill_k_pe_scale[...,0].view(-1,prefill_k_pe_int8.shape[-2]), + kv_cache, + kv_cache_scale, + attn_metadata.slot_mapping.flatten()[:num_prefill_tokens], + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + else: + ops.concat_and_cache_mla( + prefill_k_c_normed, + prefill_k[...,self.qk_nope_head_dim:], + kv_cache, + attn_metadata.slot_mapping.flatten()[:num_prefill_tokens], + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + output = torch.empty(attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens, + self.o_proj.output_size, + device=hidden_states_or_q_c.device, + dtype=hidden_states_or_q_c.dtype) + if has_prefill: + if 0 == attn_metadata.num_decode_tokens: + return self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k, kv_cache, kv_cache_scale, + attn_metadata) + output[:num_prefill_tokens] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k, kv_cache, kv_cache_scale, + attn_metadata) + + if has_decode: + if attn_metadata.num_prefill_tokens == 0: + return self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, kv_cache_scale, attn_metadata, *write_kv_cache) + output[num_prefill_tokens:] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, kv_cache_scale, attn_metadata, *write_kv_cache) + + return output diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py new file mode 100644 index 0000000..2ee66ab --- /dev/null +++ b/vllm/attention/backends/pallas.py @@ -0,0 +1,338 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. 
+ +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.logits_soft_cap = logits_soft_cap + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + + if torch_xla.tpu.version() < 4: + raise 
NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. 
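+                # Roughly: how many KV-cache pages and how many query tokens each
+                # Pallas compute block covers; the values of 16 below are assumed
+                # defaults rather than values tuned per TPU generation.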
+ num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. 
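+        # Flatten heads back into hidden_size so the result matches the
+        # documented output shape [batch_size, seq_len, num_heads * head_size].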
+ return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], + *, + attn_logits_soft_cap: Optional[float], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + return torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + attn_logits_soft_cap=attn_logits_soft_cap, + ) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py new file mode 100644 index 0000000..f1def25 --- /dev/null +++ b/vllm/attention/backends/placeholder_attn.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) +from vllm.utils import async_tensor_h2d + +# Placeholder attention backend for models like Mamba and pooling models that +# lack attention. + + +class PlaceholderAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "NO_ATTENTION" + + @staticmethod + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl + + @staticmethod + def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: + return PlaceholderAttentionMetadataBuilder + + @staticmethod + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class PlaceholderAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. 
Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Placeholder. + block_tables: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_prefill_metadata = PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=0, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + + self._cached_decode_metadata = PlaceholderAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert not turn_prefills_into_decodes, \ + ("Multi-Step + Chunked-Prefill is not supported for attention-free" + "models. turn_prefills_into_decodes is a " + "Multi-Step + Chunked-Prefill specific parameter.") + + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + # Update query lengths. 
Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + # Update sequences, masking off entries greater than num_queries + device = self.seq_lens_tensor.device + mask = torch.arange(self.seq_lens_tensor.size(0), + device=device) < num_queries + self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) + if sampled_token_ids is not None: + model_input.input_tokens.masked_scatter_( + mask, sampled_token_ids[:num_queries]) + + +class PlaceholderAttentionMetadataBuilder( + AttentionMetadataBuilder[PlaceholderAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + def prepare(self): + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + """ + is_prompt = inter_data.is_prompt + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + + # Some input builders such as ModelInputForCPUBuilder do not have the + # "inter_data_list" attribute. + # Let's check inter_data_list exists before we reference it. 
+ if hasattr(self.input_builder, "inter_data_list"): + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + if use_captured_graph: + num_decode_tokens = batch_size - self.num_prefill_tokens + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + # Placeholders + slot_mapping_tensor = torch.empty(0) + block_tables = torch.empty(0) + + return PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class PlaceholderAttentionImpl(AttentionImpl): + + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 0000000..9a4ee2a --- /dev/null +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,900 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer ROCm GPUs.""" +import itertools +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.platforms.rocm import use_rocm_custom_paged_attention + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + 
+logger = init_logger(__name__) +_PARTITION_SIZE_ROCM = 256 + + +class ROCmFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_FLASH" + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return ROCmFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Max number of query tokens among request in the batch. 
+ max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with rocm_flash_attn yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class ROCmFlashAttentionMetadataBuilder( + CommonMetadataBuilder[ROCmFlashAttentionMetadata]): + + _metadata_cls = ROCmFlashAttentionMetadata + + +def _make_alibi_bias(alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: Optional[List[int]], + make_attn_mask: bool = True) -> List[torch.Tensor]: + attn_biases = [] + if seq_lens: + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
+ bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat( + (num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) + + return attn_biases + + +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + Encoder-only attn -> select prefill sequence lengths with + bidirectional attention + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention, encoder-only + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + * Causal masking flag + ''' + + if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + + elif attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, we use the prefill sequence lengths + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + # Encoder-only models typically use bidirectional attention + causal_mask = False + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.encoder_seq_lens_tensor.device, + 
dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + key_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "ROCmFlashAttention does not support blocksparse attention.") + + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + self.logits_soft_cap = 0.0 + else: + self.logits_soft_cap = logits_soft_cap + self.attn_type = attn_type + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.use_naive_attn = False + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + if self.use_triton_flash_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Triton FlashAttention does not support attention" + " logits soft capping." 
+ " please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + if self.sliding_window != (-1, -1): + logger.warning("ROCm Triton FA does not currently support " + "sliding window attention. If using half " + "precision, please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + else: + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if not current_platform.has_device_capability(90): + self.use_naive_attn = True + else: + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + + if self.use_naive_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Naive FlashAttention does not support " + "attention logits soft capping.") + + self.attn_func = _sdpa_attention + logger.debug("Using naive (SDPA) attention in ROCmBackend") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and + cross-attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. 
+ * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + * ENCODER_ONLY: bidirectional attention with no KV caching; + use prefill sequence attributes + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + # Only update KV cache for decoder self-attention + # and encoder-decoder cross-attention + if self.attn_type not in [ + AttentionType.ENCODER, AttentionType.ENCODER_ONLY + ] and kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if key is not None and value is not None: + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if self.attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + elif self.attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, all tokens are processed in one go + num_prefill_tokens = query.shape[0] + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. 
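+        # With a mixed batch laid out as [prefill tokens | decode tokens]
+        # (see the class docstring), e.g. 5 prefill + 3 decode tokens,
+        # query[:5] feeds the prompt run and query[5:] the decode run.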
+ query = query[:num_prefill_tokens] + + # For encoder-only and encoder models, + # we process all tokens at once + # For decoder and encoder-decoder, + # we may need to limit key/value to prefill tokens + if key is not None and value is not None \ + and self.attn_type not in [AttentionType.ENCODER_DECODER, + AttentionType.ENCODER_ONLY]: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + # normal attention and DECODER + if self.attn_type == AttentionType.DECODER and ( + kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = (prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + attn_metadata.seq_lens, True) + # prefix-enabled attention and ENCODER/ENCODER_DECODER + else: + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, self.attn_type) + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + attn_masks = None + if self.use_triton_flash_attn: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + seq_lens, + make_attn_mask=causal_mask) # type: ignore + out, _ = self.attn_func( + query, + key, + value, + None, + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, + self.scale, + attn_masks[0][None] + if attn_masks is not None else None, + ) + elif self.use_naive_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. 
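+                        # e.g. 32 query heads over 8 KV heads gives
+                        # num_queries_per_kv=4, so each KV head is repeated
+                        # 4x and key/value match the query head count for the
+                        # SDPA math path.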
+ key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=causal_mask) # type: ignore + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + # sdpa math backend attention + out = self.attn_func( + query, + key, + value, + query_seq_start_loc, + num_prefill_tokens, + self.num_heads, + self.head_size, + self.scale, + attn_masks, + ) + else: + out = self.attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, + softmax_scale=self.scale, + causal=causal_mask, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # common code for prefill + assert output[:num_prefill_tokens].shape == out.shape + if output.shape[0] > num_prefill_tokens: + output[:num_prefill_tokens] = out + else: + output = out + else: + # prefix-enabled attention - + # not applicable for encoder-only models + if self.attn_type != AttentionType.ENCODER_ONLY: + output[: + num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) + # Skip decode phase for encoder-only models + if (decode_meta := attn_metadata.decode_metadata) and ( + self.attn_type != AttentionType.ENCODER_ONLY): + # Decoding run. 
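+            # Decode queries are a single token per sequence; keys/values are
+            # read from the paged KV cache populated in earlier steps.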
+ # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len, self.sliding_window) + if use_custom: + max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type + != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len) + assert max_seq_len is not None + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + if num_prefill_tokens > 0: + out = output[num_prefill_tokens:] + else: + out = output + + query_start_loc = None + ops.paged_attention_rocm( + out, + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + query_start_loc, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. 
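+        # [num_tokens, num_heads, head_size] -> [num_tokens, num_heads * head_size]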
+ return output.view(-1, self.num_heads * self.head_size) + + +def _sdpa_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + seq_lens: List[int], + num_tokens: int, + num_heads: int, + head_size: int, + scale: float, + attn_masks: Optional[List[torch.Tensor]] = None, +) -> torch.Tensor: + start = 0 + output = torch.empty((num_tokens, num_heads, head_size), + dtype=query.dtype, + device=query.device) + + for i, seq_len in enumerate(seq_lens): + end = start + seq_len + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH): + sub_out = torch.nn.functional.scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + dropout_p=0.0, + is_causal=attn_masks is None, + attn_mask=attn_masks[i] if attn_masks else None, + scale=scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + return output diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py new file mode 100644 index 0000000..afe2acf --- /dev/null +++ b/vllm/attention/backends/torch_sdpa.py @@ -0,0 +1,686 @@ +# SPDX-License-Identifier: Apache-2.0 +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +# yapf: enable +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.logger import init_logger +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + +logger = init_logger(__name__) + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TORCH_SDPA" + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. 
+ chunked_prefill: bool + seq_lens: Optional[List[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None + + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + self.encoder_attn_bias: Optional[List[torch.Tensor]] = None + self.cross_attn_bias: Optional[List[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_prefill_tokens == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def get_seq_lens( + self, + attn_type: str, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: str, + ) -> Optional[List[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. 
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: List[torch.Tensor], + attn_type: str, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: str, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = 
seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + ) + + return attn_metadata + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("Torch SPDA does not support logits soft cap. 
" + "Outputs may be slightly off.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: + raise NotImplementedError( + "Torch SDPA backend FP8 KV cache requires " + "intel_extension_for_pytorch support.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TorchSDPAMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. 
for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + if prefill_meta := attn_metadata.prefill_metadata: + assert attn_metadata.seq_lens is not None + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) + else: + # prefix-enabled attention + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + + PagedAttention.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. 
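+        # Collapse the head dimension back to
+        # [num_tokens, num_heads * head_size].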
+ return output.view(-1, self.num_heads * self.head_size) + + def _run_sdpa_forward( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: str = AttentionType.DECODER, + ) -> None: + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
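+        # e.g. seq_len=3 yields [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]; the
+        # upper-triangular -inf mask added below then enforces causality.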
+ bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py new file mode 100644 index 0000000..a6c1e27 --- /dev/null +++ b/vllm/attention/backends/triton_mla.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +import ixformer.inference.functions as ixf_ops +import vllm.envs as envs +from vllm import _custom_ops as ops + + +class TritonMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_impl_cls() -> Type["TritonMLAImpl"]: + return TritonMLAImpl + + +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "TritonMLA with FP8 KV cache not yet supported") + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache_scale: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_c_normed: torch.Tensor=None, + k_pe: torch.Tensor=None, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + q = torch.cat([q_nope, q_pe], dim=-1) + + o = torch.empty(B, + 
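+                        # decode output buffer: one latent vector of length
+                        # kv_lora_rank per query head and per sequence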
self.num_heads, + self.kv_lora_rank, + dtype=q_nope.dtype, + device=q_nope.device) + + # num_kv_splits = 4 # TODO: heuristic + + # # TODO(lucas) Allocate ahead of time + # attn_logits = torch.empty( + # ( + # B, + # self.num_heads, + # num_kv_splits, + # # NOTE(lucas) idk why the +1 is here but sglang has it so we + # # just mirror that + # self.kv_lora_rank + 1, + # ), + # dtype=torch.float32, + # device=q.device, + # ) + + # # Add a head dim of 1 + # kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(1) + # kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + # PAGE_SIZE = kv_c_and_k_pe_cache.size(2) + + # # Run MQA + # decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + # decode_meta.block_tables, + # decode_meta.seq_lens_tensor, + # num_kv_splits, self.scale, PAGE_SIZE) + + if envs.VLLM_USE_INT8_MLA: + q_int8, q_scale,_ = ops.scaled_int8_quant(q) + ixf_ops.vllm_paged_attention_mla_int8( + o, q_int8, q_scale[...,0].view(-1,q_int8.shape[-2]), kv_c_and_k_pe_cache,kv_c_and_k_pe_cache_scale, self.scale, decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_decode_seq_len,decode_meta.use_cuda_graph + ) + + else: + # fused q concat & cache write + ixf_ops.vllm_paged_attention_mla_fused( + output=o, + q_nope=q_nope, + q_pe=q_pe.contiguous(), + kv_cache=kv_c_and_k_pe_cache, + scale=self.scale, + block_tables=decode_meta.block_tables, + context_lens=decode_meta.seq_lens_tensor, + max_context_len=decode_meta.max_decode_seq_len, + k_c_normed=k_c_normed, + k_pe=k_pe, + use_cuda_graph=decode_meta.use_cuda_graph + ) + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py new file mode 100644 index 0000000..b4413c3 --- /dev/null +++ b/vllm/attention/backends/utils.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention backend utils""" +from collections import defaultdict +from contextlib import contextmanager +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union + +import numpy as np +import torch + +from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, + AttentionState) +from vllm.attention.backends.abstract import AttentionType +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import ModelRunnerBase + +# Error string(s) for encoder/decoder +# unsupported attention scenarios +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") + +PAD_SLOT_ID = -1 + +# Switch to numpy implementation of compute_slot_mapping +# if we have at least this many elements. Could be tuned further. +_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + return (isinstance(block_tables, dict) + and all(value is None for value in block_tables.values())) + + +def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, + context_len: int, sliding_window: int): + """ + Compute the start index of slot mapping. 
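+    For prefill requests with a sliding window, tokens before
+    max(0, query_len - sliding_window) fall outside the sliding window for
+    all later tokens, so no KV-cache slots need to be written for them.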
+ """ + start_idx = 0 + if is_prompt and sliding_window is not None: + start_idx = max(0, query_len - sliding_window) + return start_idx + + +def _compute_slot_mapping_python(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + for i in range(range_start, range_end): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +def _compute_slot_mapping_numpy(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + block_table_array = np.array(block_table) + idx = np.arange(range_start, range_end) + block_offset = idx % block_size + idx //= block_size + seq_slot_mapping_array = block_table_array[idx] + seq_slot_mapping_array *= block_size + seq_slot_mapping_array += block_offset + slot_mapping.extend(seq_slot_mapping_array) + + +def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], + seq_id: int, seq_len: int, context_len: int, + start_idx: int, block_size: int, + block_tables: Dict[int, List[int]]): + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + padding_mask_len = max(0, start_idx - context_len) + slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) + + range_start = max(start_idx, context_len) + range_end = seq_len + numel = range_end - range_start + block_table = block_tables[seq_id] + + # numpy implementation will be faster than python if we have + # many elements, otherwise it will be slower. 
+ if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: + _compute_slot_mapping_python(slot_mapping, block_table, range_start, + range_end, block_size) + else: + _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, + range_end, block_size) + + +TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') + + +class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): + + _metadata_cls: Type[TAttentionMetadata] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
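+
+        Returns:
+            An instance of ``self._metadata_cls`` populated with the
+            slot-mapping, block-table and sequence-length tensors
+            accumulated from the builder's sequence groups.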
+ """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class CommonAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "CommonAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def 
graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + + return attn_metadata + + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) + + def begin_forward(self, model_input) -> None: + return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. 
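+        # The tensors below are dummy placeholders sized for the capture
+        # batch; at replay time prepare_graph_input_buffers() refreshes
+        # them with the live metadata.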
+ attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 + + def _add_additonal_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. + """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) + + +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) + + +def is_all_cross_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) + + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_decode_query_kv_tokens( + attn_metadata, + attn_type: str, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill and decode tokens for query, key/value + based on the attention metadata and the specified attention type. + + Args: + attn_metadata (FlashAttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during prefill phase. + # The same input servers a both query and key. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention. 
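+        # Its length therefore equals the encoder sequence length rather
+        # than the decoder prefill length.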
+ num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py new file mode 100644 index 0000000..14c94c9 --- /dev/null +++ b/vllm/attention/backends/xformers.py @@ -0,0 +1,793 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with xFormers and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMaskWithTensorBias) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class XFormersBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "XFORMERS" + + @staticmethod + def get_impl_cls() -> Type["XFormersImpl"]: + return XFormersImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return XFormersMetadata + + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for XFormersbackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # FIXME: It is for flash attn. + # Maximum sequence length among prefill batch. 
0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + + # Self-attention prefill/decode metadata cache + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. 
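+        In addition to the encoder fields, cross_slot_mapping and
+        cross_block_tables must be populated.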
+ ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
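+            # Encoder/cross-attention fields are not split per phase; they
+            # are passed through unchanged from the full metadata.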
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: str, +) -> Optional[AttentionBias]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return attn_metadata.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: str, +) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. 
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): + + _metadata_cls = XFormersMetadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("XFormers does not support logits soft cap. " + "Outputs may be slightly off.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: "XFormersMetadata", + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + For decoder-only models: query, key and value must be non-None. 
+ + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). + Used for encoder branch of encoder-decoder models. + * ENCODER_ONLY: no kv_caching, uses the normal attention + attributes (seq_lens/seq_lens_tensor/max_seq_len). + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + attn_type = self.attn_type + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + # Self-attention vs. cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. 
for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + if key is not None and value is not None: + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # normal attention. + # block tables are empty if the prompt does not have a cached + # prefix. + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta, attn_type=attn_type) + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out + else: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have prefix attention.") + + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + + # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + out = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window, + layer._k_scale, + layer._v_scale, + ) + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. 
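+        # Rows [:num_prefill_query_tokens] hold the prefill results and the
+        # remaining rows hold the decode results; flatten the head dimensions
+        # back into the hidden dimension for the caller.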
+ return output.view(-1, self.num_heads * self.head_size) + + def _run_memory_efficient_xformers_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: str = AttentionType.DECODER, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + + Args: + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + """ + + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + attn_bias = _get_attn_bias(attn_metadata, attn_type) + if attn_bias is None: + if self.alibi_slopes is None: + + # Cross attention block of decoder branch of encoder-decoder + # model uses seq_lens for dec / encoder_seq_lens for enc + if (attn_type == AttentionType.ENCODER_DECODER): + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + + # Cross-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, + attn_metadata.encoder_seq_lens, + device=query.device) + + # Encoder branch of encoder-decoder model uses + # attn_metadata.encoder_seq_lens + elif attn_type == AttentionType.ENCODER: + + assert attn_metadata.encoder_seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens, device=query.device) + + # Self-attention block of encoder-only model just + # uses the seq_lens directly. 
+ elif attn_type == AttentionType.ENCODER_ONLY: + assert attn_metadata.seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, device=query.device) + + # Self-attention block of decoder branch just + # uses the seq_lens directly + elif attn_type == AttentionType.DECODER: + assert attn_metadata.seq_lens is not None + + # Decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens, device=query.device) + else: + raise ValueError("Unknown AttentionType: %s", attn_type) + + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + attn_bias = [attn_bias] + else: + assert attn_type == AttentionType.DECODER + assert attn_metadata.seq_lens is not None + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) + + _set_attn_bias(attn_metadata, attn_bias, attn_type) + + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale) + return out.view_as(original_query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + assert attn_metadata.seq_lens is not None + output = torch.empty_like(original_query) + start = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=self.scale) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.view_as(original_query[start:end])) + start += seq_len + return output + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[AttentionBias]: + attn_biases: List[AttentionBias] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. 
+ bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py new file mode 100644 index 0000000..c9e6b73 --- /dev/null +++ b/vllm/attention/layer.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer.""" +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import vllm.envs as envs +from vllm.attention import AttentionType +from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import _Backend, current_platform +from vllm.utils import direct_register_custom_op + +import ixformer.contrib.vllm_flash_attn as ops + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + **extra_impl_args, + ) -> None: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + """ + super().__init__() + if per_layer_sliding_window is not None: + # per-layer sliding window + sliding_window = per_layer_sliding_window + elif cache_config is not None: + # model-level sliding window + sliding_window = cache_config.sliding_window + else: + sliding_window = None + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + is_attention_free = False + calculate_kv_scales = False + if num_kv_heads is None: + num_kv_heads = num_heads + + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. 
+ self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + # FlashAttn doesn't support quantizing the kv-cache only + # but requires q to be quantized as well. + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep the float32 versions of k/v_scale for attention + # backends that don't support tensors (Flashinfer) + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + self.use_mla = use_mla + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window + + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None and not isinstance( + quant_method, UnquantizedLinearMethod): + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + blocksparse_params is not None, + use_mla=use_mla) + impl_cls = attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **extra_impl_args) + self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dtype = dtype + + # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # torch.compile works by registering the attention as one giant + # opaque custom op. For other platforms, we directly call them + # and let torch.compile handle them. 
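+        # The direct-call path is always taken here, so the
+        # torch.ops.vllm.unified_attention* branches in forward() are never
+        # taken.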
+ self.use_direct_call = True + + self.use_output = attn_backend.accept_output_buffer + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + self.attn_type = attn_type + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + # this variable will not be accessed if use_direct_call is True + self.kv_cache = [ + torch.tensor([]) for _ in range( + get_current_vllm_config().parallel_config.num_virtual_engine) + ] + self.kv_cache_scale = [ + torch.tensor([]) for _ in range( + get_current_vllm_config().parallel_config.num_virtual_engine) + ] + + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. + output_shape: Optional[torch.Size] = None, + ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(query, key, value) + if self.use_output: + output_shape = (output_shape + if output_shape is not None else query.shape) + output = torch.empty(output_shape, + dtype=query.dtype, + device=query.device) + # hidden_size = output_shape[-1] + # We skip reshaping query, key and value tensors for the MLA + # backend since these tensors have different semantics and are + # processed differently. + if not self.use_mla: + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. 
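+                # Shapes after these views (illustrative, hidden = heads * head_size):
+                #   query:     (num_tokens, num_heads * head_size)    -> (num_tokens, num_heads, head_size)
+                #   key/value: (num_tokens, num_kv_heads * head_size) -> (num_tokens, num_kv_heads, head_size)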
+ query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine] + return self.impl.forward(self, + query, + key, + value, + self_kv_cache, + self_kv_cache_scale, + attn_metadata, + output=output) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name) + # return output.view(-1, hidden_size) + else: + if self.use_direct_call: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine] + return self.impl.forward(self, query, key, value, + self_kv_cache,self_kv_cache_scale, attn_metadata) + else: + return torch.ops.vllm.unified_attention( + query, key, value, self.layer_name) + + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) + self._k_scale.copy_(torch.abs(key).max() / self.k_range) + self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + # We only calculate the scales once + self.calculate_kv_scales = False + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" + return s + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if hasattr(self.impl, "process_weights_after_loading"): + self.impl.process_weights_after_loading(act_dtype) + + +class MultiHeadAttention(nn.Module): + """Multi-headed attention without any cache, used for ViT.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + # assert self.num_heads % self.num_kv_heads == 0 + # self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + # dtype = torch.get_default_dtype() + # attn_backend = get_attn_backend(head_size, + # dtype, + # kv_cache_dtype=None, + # block_size=16, + # is_attention_free=False) + # backend = backend_name_to_enum(attn_backend.get_name()) + # if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: + # backend = _Backend.XFORMERS + + # self.attn_backend = backend if backend in { + # _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + # } else _Backend.TORCH_SDPA + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """Input shape: batch_size x seq_len x hidden_size""" + # TODO(Isotr0py): Use existing backend implementations and support FA3 + bsz, q_len, _ = query.size() + kv_len = key.size(1) + + # query = query.view(bsz, q_len, self.num_heads, self.head_size) + # 
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) + # value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + + # if (num_repeat := self.num_queries_per_kv) > 1: + # # Handle MQA and GQA + # key = torch.repeat_interleave(key, num_repeat, dim=2) + # value = torch.repeat_interleave(value, num_repeat, dim=2) + + # if self.attn_backend == _Backend.XFORMERS: + # from xformers import ops as xops + + # out = xops.memory_efficient_attention_forward(query, + # key, + # value, + # scale=self.scale) + # elif self.attn_backend == _Backend.TORCH_SDPA: + # query, key, value = (x.transpose(1, 2) + # for x in (query, key, value)) + # out = F.scaled_dot_product_attention(query, + # key, + # value, + # scale=self.scale) + # out = out.transpose(1, 2) + # elif self.attn_backend == _Backend.PALLAS_VLLM_V1: + # query, key, value = (x.transpose(1, 2) + # for x in (query, key, value)) + # from torch_xla.experimental.custom_kernel import flash_attention + # out = flash_attention(query, key, value, sm_scale=self.scale) + # out = out.transpose(1, 2) + + # return out.reshape(bsz, q_len, -1) + query = query.view(bsz * q_len, self.num_heads, self.head_size) + key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size) + cu_q = torch.tensor([0,] + [q_len for _ in range(bsz)], device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + cu_kv = torch.tensor([0,] + [kv_len for _ in range(bsz)], device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + out = ops.flash_attn_varlen_func( + query, + key, + value, + cu_q, + cu_kv, + q_len, + kv_len, + softmax_scale=self.scale, + causal=False, + ) + + return out.view(bsz, q_len, -1) + + +def unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) + + +def unified_attention_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(query).contiguous() + + +direct_register_custom_op( + op_name="unified_attention", + op_func=unified_attention, + mutates_args=[], + fake_impl=unified_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output) + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/attention/ops/__init__.py b/vllm/attention/ops/__init__.py new file mode 
100644 index 0000000..e69de29 diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/vllm/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000..71caf3c --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, + *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + 
layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD, + other=0.0) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD, + other=0.0) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = 
tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + other=0.0, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000..6ab69ea --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch + +from vllm.platforms import current_platform + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) + +if IS_COMPUTE_8_OR_ABOVE: + 
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = current_platform.is_rocm() or \ + current_platform.is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if current_platform.is_cuda_alike() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." + ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." 
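+        # Illustrative cu_seqlens layout for two prefill samples of lengths 5
+        # and 3: cu_seqlens_k = tensor([0, 5, 8]); sample i occupies
+        # k[cu_seqlens_k[i]:cu_seqlens_k[i + 1]].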
+ + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." + + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v]) + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. 
+ """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000..4de9bd5 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import numpy as np +import torch +import triton + + +class csr_matrix: + """Simple implementation of CSR matrix conversion without scipy. + This replaced scipy.sparse.csr_matrix() previously used.""" + + def __init__(self, input_array): + if not isinstance(input_array, np.ndarray): + raise ValueError("Input must be a NumPy array") + + self.shape = input_array.shape + rows, cols = self.shape + data = [] + indices = [] + indptr = [0] + + for i in range(rows): + for j in range(cols): + if input_array[i, j]: + data.append(input_array[i, j]) + indices.append(j) + indptr.append(len(indices)) + + self.data = np.array(data) + self.indices = np.array(indices) + self.indptr = np.array(indptr) + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. + NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. 
+ - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. 
+ - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py new file mode 100644 index 0000000..1b47581 --- /dev/null +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch +import triton +import triton.language as tl + +from vllm import _custom_ops as ops +from vllm.platforms.rocm import use_rocm_custom_paged_attention + +from .prefix_prefill import context_attention_fwd + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def kernel_paged_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: 
tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + x: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_k_cache_4: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] +): + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + if filter_by_query_len: + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + + 1) + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + if cur_batch_query_len > 1: + return + else: + cur_batch_in_all_start_index = seq_idx + + query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( + 0, num_queries_per_kv_padded) + + query_offset = (cur_batch_in_all_start_index * query_stride_0 + + query_head_idx[:, None] * query_stride_1) + + head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv + head_mask = head_mask & (query_head_idx < num_query_heads) + + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # Q : (num_queries_per_kv, HEAD_SIZE,) + Q = tl.load( + query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + mask=dim_mask[None, :] & head_mask[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) + acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], + dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, + mask=head_mask, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[None, :] * stride_v_cache_2 + + offs_n[:, None] * stride_v_cache_3) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) + seq_mask = seq_offset[None, :] < boundary + + # S : (num_queries_per_kv, BLOCK_SIZE,) + S = tl.where(head_mask[:, None] & seq_mask, 0.0, + float("-inf")).to(tl.float32) 
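+        # The lines below are the streaming-softmax (flash-attention style)
+        # update for this tile, summarized:
+        #   m_j = max(M, max_n(S));  P = exp(S - m_j)
+        #   acc = acc * exp(M - m_j) + P @ V
+        #   L   = L   * exp(M - m_j) + sum_n(P)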
+ S += scale * tl.dot(Q, K) + + context_len = seq_len - 1 + + if SLIDING_WINDOW > 0: + S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, + -10000) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (num_queries_per_kv,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # P : (num_queries_per_kv, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (num_queries_per_kv,) + l_j = tl.sum(P, axis=1) + + # alpha : (num_queries_per_kv, ) + alpha = tl.exp(M - m_j) + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1) + + tl.store( + output_ptr + output_offset[:, None] + + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + acc, + mask=dim_mask[None, :] & head_mask[:, None], + ) + + +def chunked_prefill_paged_decode( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_table, + query_start_loc, + seq_lens, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, +): + + if sm_scale is None: + sm_scale = 1.0 / (query.shape[1]**0.5) + + use_alibi_slopes = alibi_slopes is not None + + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if max_query_len > 1: + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + kv_cache_dtype=kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=block_table, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_seq_len=max_seq_len, + max_input_len=max_query_len, + k_scale=k_scale, + v_scale=v_scale, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + sm_scale=sm_scale, + skip_decode=True, + ) + + block_size = value_cache.shape[3] + num_seqs = len(seq_lens) + num_query_heads = query.shape[1] + num_kv_heads = key.shape[1] + num_queries_per_kv = query.shape[1] // key.shape[1] + head_size = query.shape[2] + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert key_cache.dtype == torch.uint8 + assert value_cache.dtype == torch.uint8 + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + key_cache = key_cache.view(target_dtype) + value_cache = value_cache.view(target_dtype) + + num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), + 16) + + use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, + block_size, + num_queries_per_kv, + max_seq_len, sliding_window) + if use_custom: + _PARTITION_SIZE_ROCM = 256 + max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + total_num_seq = query.shape[0] + tmp_output = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions, + head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_rocm( + output, + exp_sums, 
+ max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale=sm_scale, + block_tables=block_table, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + block_size=block_size, + max_seq_len=max_seq_len, + alibi_slopes=alibi_slopes, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + ) + else: + kernel_paged_attention_2d[( + num_seqs, + num_kv_heads, + )]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_table, + seq_lens_ptr=seq_lens, + alibi_slopes_ptr=alibi_slopes, + scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv_padded=num_queries_per_kv_padded, + block_table_stride=block_table.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + SLIDING_WINDOW=sliding_window, + x=key_cache.shape[4], + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4), + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + filter_by_query_len=True, + query_start_len_ptr=query_start_loc, + ) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py new file mode 100644 index 0000000..18b69a6 --- /dev/null +++ b/vllm/attention/ops/flashmla.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py +from typing import Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +if current_platform.is_cuda(): + try: + import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True + except ImportError: + _flashmla_C_AVAILABLE = False +else: + _flashmla_C_AVAILABLE = False + + +def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + if not current_platform.is_cuda(): + return False, "FlashMLA is only supported on CUDA devices." + if current_platform.get_device_capability()[0] != 9: + return False, "FlashMLA is only supported on Hopper devices." + if not _flashmla_C_AVAILABLE: + return False, "vllm._flashmla_C is not available, likely was not "\ + "compiled due to insufficient nvcc version or a supported arch "\ + "(only sm90a currently) was not in the list of target arches to "\ + "compile for." + return True, None + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. 
+ """ + return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse + + +# +# TODO: Add fake functions +# +# @register_fake("_flashmla_C::get_mla_metadata") +# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# +# @register_fake("_flashmla_C::fwd_kvcache_mla") +# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py new file mode 100644 index 0000000..49ea420 --- /dev/null +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +from vllm_hpu_extension import cache_ops, ops + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
+_PARTITION_SIZE = 512 + + +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_indices: Optional[torch.Tensor] + block_offsets: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] + + +class HPUPagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, kv_cache_dtype: str, + is_prompt: bool) -> None: + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, is_prompt) + + @staticmethod + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + ) -> torch.Tensor: + raise NotImplementedError( + "forward_prefix is not implemented for HPUPagedAttention") + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py new file mode 100644 index 0000000..6d96f58 --- /dev/null +++ b/vllm/attention/ops/ipex_attn.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional, Tuple + +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +except ImportError: + _use_ipex = False + +import torch + +from vllm import _custom_ops as ops + + +class _PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> 
Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, 
query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + +PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py new file mode 100644 index 0000000..6bce587 --- /dev/null +++ b/vllm/attention/ops/nki_flash_attn.py @@ -0,0 +1,905 @@ +# SPDX-License-Identifier: Apache-2.0 + +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.language as nl +import numpy as np +import torch +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def is_power_of_2(x): + return x > 0 and (x & (x - 1)) == 0 + + +@nki.jit +def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): + """ + Load block tables from HBM into SRAM + + `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. + In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. + """ + B_P_SIZE = 128 + + # reshape as `(num_tiles, num_blocks_per_tile)` + assert len(block_tables_hbm.shape) == 1 + (num_total_blocks, ) = block_tables_hbm.shape + assert num_blocks_per_tile * num_tiles == num_total_blocks + block_tables_hbm = block_tables_hbm.reshape( + (num_tiles, num_blocks_per_tile)) + + block_tables_sbuf = nl.zeros( + (ceil_div(num_tiles, + B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), + dtype=nl.int32, + ) + for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(num_blocks_per_tile)[None, :] + block_tables_sbuf[i, i_p, i_f] = nl.load( + block_tables_hbm[i_p + i * B_P_SIZE, i_f], + dtype=nl.int32, + mask=(i_p + i * B_P_SIZE < num_tiles), + ) + return block_tables_sbuf + + +@nki.jit +def transform_block_tables_for_indirect_load( + block_tables, + block_size_tiling_factor, + num_head, + head_id, +): + """ + This function does two things: + 1. calculate new `block_tables` for a `head_id` after flattening + `num_block`, `num_head`, and `block_size_tiling_factor` dimensions + 2. transpose the result so that `block_table` for each tile is mapped to + SBUF Partition dimension for vectorized DMA + + Tiling trick to further improve DMA performance: + Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M + blocks of a given `head_id` from HBM, the load `cache[block_tables, + head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not + fully utilize hardware parallelization. The solution is to tile `block_size` + into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * + block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape + `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. + + Note: + We don't further tile D dimension as small DMA size also hurts performance. 
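+
+    Worked example (illustrative): with num_blocks_per_tile M = 4 and
+    B_P_SIZE = 128, block_size_tiling_factor = 128 / 4 = 32, so a cache block
+    of block_size = 64 is viewed as (32, tiled_block_size = 2, D).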
+ """ + B_P_SIZE = 128 + num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( + block_tables.shape) + assert num_tiles_per_partition == B_P_SIZE + assert is_power_of_2( + num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" + + num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE) + block_tables_transposed = nl.ndarray( + ( + num_loads, + par_dim(B_P_SIZE), + num_partitions * num_tiles_per_partition, + ), + dtype=nl.int32, + ) + + # prepare iota ahead of time to avoid repeatedly using Gpsimd + if num_head > 1: + head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) + head_id = nl.transpose( + head_id.broadcast_to((1, num_tiles_per_partition))) + if num_blocks_per_tile > 1: + head_id = head_id.broadcast_to( + (num_tiles_per_partition, num_blocks_per_tile)) + + if block_size_tiling_factor > 1: + broadcast_shape = ( + num_tiles_per_partition, + num_blocks_per_tile, + block_size_tiling_factor, + ) + offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], + dtype=nl.int32).broadcast_to(broadcast_shape) + + for partition_id in nl.affine_range(num_partitions): + block_tables_partition = block_tables[partition_id] + if num_head > 1: + # fuse num_block and num_head dimension + block_tables_partition = block_tables_partition * num_head + head_id + + if block_size_tiling_factor > 1: + # need to apply block size tiling trick + assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE + block_tables_partition = ((block_tables_partition * + block_size_tiling_factor).reshape( + (num_tiles_per_partition, + num_blocks_per_tile, + 1)).broadcast_to(broadcast_shape)) + new_block_tables = block_tables_partition + offset + new_block_tables = new_block_tables.reshape( + (num_tiles_per_partition, B_P_SIZE)) + else: + new_block_tables = block_tables_partition + + # transpose the block table so that it can be used by vector DGE + for i in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = (partition_id * num_tiles_per_partition + + nl.arange(num_tiles_per_partition)[None, :]) + block_tables_transposed[i, i_p, i_f] = nl.transpose( + new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) + return block_tables_transposed + + +@nki.jit +def load_kv_tile_from_cache( + cur_k_tile, + cur_v_tile, + kv_cache, + block_tables, + large_k_tile_idx, + num_blocks_per_large_tile, + tiled_block_size, + B_P_SIZE, + B_D_SIZE, +): + """ + Load KV cache and transform Key and Value into layout required by Matmul + + Vectorized DMA Load layout: + Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + + Layout used by attention matmuls: + Key: (par_dim(B_D_SIZE), seqlen_kv) + Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) + equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + """ + # load key cache + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + for load_idx in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_k_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) + # Transpose SBUF tensor using PE + for tb_i in nl.affine_range(tiled_block_size): + cur_k_tile[ + :, + nl.ds( + load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, + B_P_SIZE, + ), + ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) + + # load value cache + for load_idx in nl.affine_range(num_loads): + loaded = nl.load(kv_cache[1, 
block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + cur_v_tile[ + :, + nl.ds( + load_idx * tiled_block_size * B_D_SIZE, + tiled_block_size * B_D_SIZE, + ), + ] = loaded + + +@nki.jit +def transpose_p_local(p_local_transposed, + p_local, + LARGE_TILE_SZ, + B_F_SIZE=512): + for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.sbuf, + dtype=p_local.dtype) + else: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.psum, + dtype=np.float32) + + for j in nl.affine_range(B_F_SIZE // 128): + j_128_slice = nl.ds(j * 128, 128) + i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) + + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( + p_local[:, i_j_128_slice]) + else: + p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( + p_local[:, i_j_128_slice]) + + p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( + p_local_t_tmp, dtype=p_local_transposed.dtype) + + +@nki.jit +def _flash_attention_core( + q_local_tile, + k, + v, + o_buffer, + l_buffer, + m_buffer, + kernel_dtype, + acc_type, + tile_mask, + use_causal_mask, + q_tile_idx=None, + initialize=False, + LARGE_TILE_SZ=2048, + B_P_SIZE=128, + B_F_SIZE=512, + B_D_SIZE=128, + qk_res_buffer=None, +): + """ + The flash attention core function to calculate self attention between a tile + of q and a block of K and V. + The q_local_tile has (B_P_SIZE, B_D_SIZE) + The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will + be split into size B_F_SIZE tiles + + The results are stored in the following three buffers + o_buffer: (B_P_SIZE, d) + l_buffer: (B_P_SIZE, 1) + m_buffer: (B_P_SIZE, 1) + + All IO buffers are in SBUF. 
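+
+    o_buffer accumulates the un-normalized attention output (scores are
+    exponentiated relative to the running row maximum in m_buffer) and
+    l_buffer tracks the running log-sum-exp; the caller applies the final
+    normalization out = o * exp(m - l) after all K/V tiles are processed.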
+ """ + num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE + + qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + buffer=nl.sbuf, + dtype=acc_type) + max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), + dtype=acc_type) + for k_i in nl.affine_range(num_k_tile_per_large_tile): + k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) + + if use_causal_mask: + # mask are used to only apply computation to the lower half of the + # matrix, which reduce the arithmetic intensity by up to 50% + multiplication_required_selection = (q_tile_idx * B_P_SIZE + >= k_i * B_F_SIZE) + else: + multiplication_required_selection = True + + if multiplication_required_selection: + qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), + dtype=np.float32, + buffer=nl.psum) # (128, 512) + qk_psum[:, :] = nl.matmul(q_local_tile, + k[:, k_i_b_f_slice], + transpose_x=True) # (p(128), 512) + qk_res_buf[:, k_i_b_f_slice] = nl.where( + tile_mask[:, k_i_b_f_slice], + qk_psum[:, nl.ds(0, B_F_SIZE)], + -9984.0, + dtype=acc_type, + ) + else: + qk_res_buf[:, k_i_b_f_slice] = -9984.0 + + # Calculate max of the current tile + max_local[:, k_i] = nisa.tensor_reduce( + np.max, + qk_res_buf[:, k_i_b_f_slice], + axis=(1, ), + dtype=acc_type, + negate=False, + ) + + if qk_res_buffer is not None: + qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) + + max_ = nisa.tensor_reduce( + np.max, + max_local[:, :], + axis=(1, ), + dtype=acc_type, + negate=False, + ) + + o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), + dtype=o_buffer.dtype) + + if initialize: + m_buffer[:, 0] = nl.copy(max_) + m_current = max_ + else: + m_previous = nl.copy(m_buffer[:, 0]) + m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) + + m_current = m_buffer[:, 0] + # Compute scaling factor + alpha = nisa.activation( + np.exp, + m_previous, + bias=-1 * m_current, + scale=1.0, + ) + o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha) + + p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) + + p_partial_sum = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), + dtype=acc_type, + ) + + for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): + k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) + + # compute exp(qk - max) + # Compute partial row - tile sum of exp(qk - max)) + # FIXME : Use activation accumulate to accumulate over k_r_i loop ? 
+ p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( + np.exp, + qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, + scale=1.0, + reduce_op=nl.add, + reduce_res=p_partial_sum[:, k_r_i], + dtype=kernel_dtype, + ) + + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) + + p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + transpose_p_local( + p_local_transposed=p_local_transposed, + p_local=p_local, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_F_SIZE=B_F_SIZE, + ) + + pv_psum = nl.zeros( + (par_dim(B_P_SIZE), B_D_SIZE), + dtype=np.float32, + buffer=nl.psum, + ) + for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): + pv_psum[:, :] += nl.matmul( + p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], + v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], + transpose_x=True, + ) # (128, 128) (p(Br), d) + + if initialize: + o_buffer[:, :] = nl.copy(pv_psum[:, :]) + l_buffer[:, 0] = nl.add(nl.log(ps), max_) + else: + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) + + l_prev = l_buffer[:, 0] + l_exp = nl.add( + nl.exp(nl.subtract(l_prev, m_current)), + ps, + ) + l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp)) + + +@nki.jit +def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): + B_P_SIZE = 128 + B_D_SIZE = v_hbm_tile.shape[-1] + loaded = nl.load(v_hbm_tile[ + nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), + :, + ]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded + + +@nki.jit +def flash_paged_attention( + query, + key, + value, + kv_cache, + block_tables, + mask, + softmax_scale=None, + mixed_precision=True, + LARGE_TILE_SZ=2048, + return_debug_tensors=False, +): + """ + Flash PagedAttention Forward Kernel. + + IO tensor layouts: + - query: shape (1, n_heads, d, seq_q) + - key: shape (1, n_kv_heads, d, seq_k) + - value: shape (1, n_kv_heads, seq_v, d) + - kv_cache: (2, num_blocks, n_kv_heads, block_size, d) + - block_tables: (num_active_blocks, ) + - mask: (seq_q, num_active_blocks * block_size + seq_q) + - o: shape (1, n_heads, seq_q, d) + + - This kernel requires seq_k == seq_v + - We use continuous batching by default, so the batch dimension is + always 1, and different requests are concatenated along sequence + dimension. + - We use paged cache blocks (kv_cache) to store KV cache. + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype except for + block_tables (int32) and mask (int32) + - If mixed_percision is True, then all Tensor Engine operation will be + performed in bfloat16 and accumulation will be performed in float32. + Otherwise the intermediates will be in the same type as the inputs. 
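+
+    Note:
+    - `mask` is expected in the reordered layout produced by
+      `reorder_context_mask` (defined later in this file), which matches the
+      strided token order produced by the vectorized KV cache load (a no-op
+      when tiled_block_size == 1).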
+ + Compile-time Constants: + - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` + - mixed_precision: flag to set non-matmul ops in fp32 precision, default + is set to `true`, if false, we use same precision as input types + - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention + computation reduction + + GQA support Notes: + the spmd kernel for launching kernel should be on kv_heads instead of + nheads + + Example usage: + MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] + usage: `flash_fwd[b, h](q, k, v, ...)` + GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] + usage: `flash_fwd[b, kv_h](q, k, v, ...)` + """ + B_F_SIZE = 512 + B_P_SIZE = 128 + b, h, d, seqlen_q = query.shape + B_D_SIZE = d + n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine + _, num_blocks, k_h, block_size, _ = kv_cache.shape + q_h_per_k_h = h // k_h + assert b == 1, f"invalid batch size {b=}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" + cache_shape = (2, num_blocks, k_h, block_size, d) + assert (tuple(kv_cache.shape) == cache_shape + ), f"{kv_cache.shape=} mismatch, expect {cache_shape}" + assert key is None or tuple(key.shape) == ( + 1, + k_h, + d, + seqlen_q, + ), f"key shape {key.shape} mismatch!" + assert value is None or tuple(value.shape) == ( + 1, + k_h, + seqlen_q, + d, + ), f"value shape {value.shape} mismatch!" + + assert ( + nl.program_ndim() == 2 + ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + (num_active_blocks, ) = block_tables.shape + context_kv_len = num_active_blocks * block_size + assert ( + LARGE_TILE_SZ % B_F_SIZE == 0 + ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" + assert (context_kv_len % LARGE_TILE_SZ == 0 + ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" + + num_blocks_per_large_tile = LARGE_TILE_SZ // block_size + assert is_power_of_2( + num_blocks_per_large_tile + ), f"{num_blocks_per_large_tile=} is expected of be power of 2" + if seqlen_q > B_F_SIZE: + MAX_REDUCTION_TILE = 2048 + if seqlen_q // 2 > MAX_REDUCTION_TILE: + assert ( + seqlen_q % MAX_REDUCTION_TILE == 0 + ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" + else: + assert (seqlen_q % B_F_SIZE == 0 + ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" + + kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype + acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + softmax_scale = softmax_scale or (1.0 / (d**0.5)) + num_large_k_tile = context_kv_len // LARGE_TILE_SZ + + o = nl.ndarray((b, h, seqlen_q, d), + dtype=query.dtype, + buffer=nl.shared_hbm) + hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( + None, + None, + None, + None, + ) + if return_debug_tensors: + hbm_l_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_m_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + qk_res_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + block_tables_sbuf = load_block_tables( + block_tables_hbm=block_tables, + num_tiles=num_large_k_tile, + num_blocks_per_tile=num_blocks_per_large_tile, + ) + + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient + if 
num_blocks_per_large_tile < B_P_SIZE: + # we checked num_blocks_per_tile is a power of 2 + assert B_P_SIZE % num_blocks_per_large_tile == 0 + block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile + # We assume block_size >= block_size_tiling_factor + assert block_size % block_size_tiling_factor == 0 + else: + block_size_tiling_factor = 1 + tiled_block_size = block_size // block_size_tiling_factor + + # Indirect DMA load must be placed along Partition Dimension + block_tables_sbuf = transform_block_tables_for_indirect_load( + block_tables_sbuf, + block_size_tiling_factor=block_size_tiling_factor, + num_head=k_h, + head_id=head_id, + ) + + # Flatten KV cache to be 3D for loading into SBUF + new_cache_shape = ( + 2, + num_blocks * k_h * block_size_tiling_factor, + tiled_block_size * d, + ) + kv_cache = kv_cache.reshape(new_cache_shape) + + # Global Flash Attention accumulators + o_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + l_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + m_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + + for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + cur_k_tile = nl.ndarray( + (par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype, + ) + cur_v_tile = nl.ndarray( + (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), + dtype=kernel_dtype, + ) + load_kv_tile_from_cache( + cur_k_tile=cur_k_tile, + cur_v_tile=cur_v_tile, + kv_cache=kv_cache, + block_tables=block_tables_sbuf, + large_k_tile_idx=large_k_tile_idx, + num_blocks_per_large_tile=num_blocks_per_large_tile, + tiled_block_size=tiled_block_size, + B_P_SIZE=B_P_SIZE, + B_D_SIZE=B_D_SIZE, + ) + + for i in nl.affine_range(n_tile_q): + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), + ]) + for i_q_h in nl.affine_range(q_h_per_k_h): + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) + q_tile[:, :] = q_sbuf_tile * softmax_scale + + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + kernel_dtype=kernel_dtype, + acc_type=acc_type, + tile_mask=cur_mask, + use_causal_mask=False, + q_tile_idx=i, + initialize=large_k_tile_idx == 0, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + ) + + # compute attention between input query, key and value + if key is not None and value is not None: + B_F_SIZE = min(seqlen_q, B_F_SIZE) + LARGE_TILE_SZ = seqlen_q + + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + cur_v_tile = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), + dtype=kernel_dtype, + ) + + loaded = nl.load(key[batch_id, head_id, :, :]) + if loaded.dtype != kernel_dtype: + loaded = nl.copy(loaded, dtype=kernel_dtype) + cur_k_tile[:, :] = loaded + + v_hbm_tile = value[batch_id, head_id] + for v_i in nl.affine_range(LARGE_TILE_SZ // 
B_P_SIZE): + load_v_tile( + v_hbm_tile=v_hbm_tile, + cur_v_tile=cur_v_tile, + large_tile_idx=0, + v_i=v_i, + LARGE_TILE_SZ=LARGE_TILE_SZ, + ) + + for i in nl.affine_range(n_tile_q): + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(context_kv_len, LARGE_TILE_SZ), + ]) + for i_q_h in nl.affine_range(q_h_per_k_h): + + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) + q_tile[:, :] = q_sbuf_tile * softmax_scale + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + kernel_dtype=kernel_dtype, + acc_type=acc_type, + tile_mask=cur_mask, + use_causal_mask=True, + q_tile_idx=i, + initialize=False, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + qk_res_buffer=(qk_res_buffer[i, i_q_h] + if qk_res_buffer is not None else None), + ) + + # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + out = nl.multiply( + o_buffer[i, i_q_h], + nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), + dtype=kernel_dtype, + ) + + nl.store( + o[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + :, + ], + out, + ) + # maximum and summation statistics + if return_debug_tensors: + nl.store( + hbm_m_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + m_buffer[i, i_q_h, :, :], + ) + nl.store( + hbm_l_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + l_buffer[i, i_q_h], + ) + nl.store( + hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], + qk_res_buffer[batch_id, i_q_h, :, :], + ) + + if return_debug_tensors: + return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res + return o + + +def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): + """ + Reorder the mask to make it compatible with the flash attention kernel. + + We vectorize KV cache read to improve DMA utilization. However, the layout + that maximizes DMA bandwidth changes the order tokens are consumed. + + The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, + tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And + each step the engine consumes a column (rather than a row) of B_P_SIZE + tokens. Therefore, the tokens are visited in a strided way. + + To make sure mask matches the order tokens are consumed, we need to properly + transpose mask. 
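+
+    Concretely, within each group of B_P_SIZE * tiled_block_size context
+    tokens, the mask column at original position p * tiled_block_size + t
+    (p < B_P_SIZE, t < tiled_block_size) moves to position t * B_P_SIZE + p;
+    the query columns appended after the context are left untouched.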
+ """ + total_query_len, total_seq_len = mask.shape + context_kv_len = total_seq_len - total_query_len + + B_P_SIZE = 128 + assert (LARGE_TILE_SZ + >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" + num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) + tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks + if tiled_block_size > 1: + # Mask reordering is needed when tiled_block_size > 1 + device = mask.device + mask = mask.cpu() + context_mask = mask[:, :context_kv_len] + context_mask = context_mask.view( + total_query_len, + context_kv_len // LARGE_TILE_SZ, + num_tiled_blocks // B_P_SIZE, + B_P_SIZE, + tiled_block_size, + ) + context_mask = context_mask.transpose(3, 4).reshape( + total_query_len, context_kv_len) + new_mask = mask[:, context_kv_len:] + return torch.concat([context_mask, new_mask], dim=1).to(device) + else: + return mask + + +def flash_attn_varlen_nkifunc( + query, + key, + value, + kv_cache, + block_table, + attn_mask, + n_kv_head=None, + head_size=None, + LARGE_TILE_SZ=2048, + mixed_precision=True, +): + """ + Compute flash paged attention for variable length sequences. + + This function is a wrapper around the flash attention NKI kernel. It takes + in the following arguments: + - query: (1, n_heads, d, seq_q) + - key: (1, n_kv_heads, d, seq_k) + - value: (1, n_kv_heads, seq_v, d) + - kv_cache: (2, n_blocks, n_kv_heads, block_size, d) + - block_tables: (n_active_blocks, ) + - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) + + Notes: + - attn_mask must be reordered outside using `reorder_context_mask` + - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) + for better DMA throughput + """ + if n_kv_head is None: + n_kv_head = kv_cache.shape[2] + assert kv_cache.shape[0] == 2 + assert kv_cache.shape[2] == n_kv_head + if head_size is None: + head_size = kv_cache.shape[-1] + + kwargs = dict( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + block_tables=block_table, + mask=attn_mask, + softmax_scale=1.0 / (head_size**0.5), + mixed_precision=mixed_precision, + LARGE_TILE_SZ=LARGE_TILE_SZ, + ) + + o = flash_paged_attention[1, n_kv_head](**kwargs) + return o + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + """ + Writes key-value pairs to the KV cache at specified positions. 
+ + Args: + key (torch.Tensor): Key tensor with shape + (num_tokens, n_kv_head, d_head) + value (torch.Tensor): Value tensor with shape + (num_tokens, n_kv_head, d_head) + kv_cache (torch.Tensor): Key/value cache tensor with shape + (2, num_blocks, n_kv_head, block_size, d_head) + slot_mapping (torch.Tensor): Mapping tensor indicating cache positions + with shape (num_tokens) + + Returns: + None: Updates the kv_cache tensor in-place + """ + block_size = kv_cache.size(3) + n_kv_head = key.size(1) + + # Calculate indices with explicit floor division + block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + block_offsets = slot_mapping % block_size + + # Create the head indices tensor + head_indices = torch.arange(n_kv_head, device=key.device) + + # Update caches using index_put_ + kv_cache.index_put_( + (torch.tensor([0], device=key.device), block_indices[:, None], + head_indices[None, :], block_offsets[:, None]), key) + + kv_cache.index_put_( + (torch.tensor([1], device=key.device), block_indices[:, None], + head_indices[None, :], block_offsets[:, None]), value) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py new file mode 100644 index 0000000..827c304 --- /dev/null +++ b/vllm/attention/ops/paged_attn.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. 
+ block_tables: Optional[torch.Tensor] + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 120, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = (max_seq_len <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + else: + # Run PagedAttention V2. 
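+            # V2 splits each sequence into partitions of _PARTITION_SIZE
+            # tokens, writes per-partition results into tmp_output together
+            # with per-partition softmax statistics (exp_sums, max_logits),
+            # and then reduces across partitions into `output`.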
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> torch.Tensor: + output = torch.empty_like(query) + max_seq_len = None + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc, + seq_lens_tensor, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) + return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py new file mode 100644 index 0000000..e0478c2 --- /dev/null +++ b/vllm/attention/ops/prefix_prefill.py @@ -0,0 +1,894 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +# Static kernels parameters +BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 +NUM_WARPS = 4 if current_platform.is_rocm() else 8 + +# To check compatibility +IS_TURING = current_platform.get_device_capability() == (7, 5) + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + 
stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + SKIP_DECODE: tl.constexpr, + ): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) # [N] + # [D,N] + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. 
+ # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) + < SLIDING_WINDOW, qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) + return + + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, 
+ Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + 
(cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc 
= tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, 
k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. 
There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SKIP_DECODE=skip_decode, + num_warps=NUM_WARPS, + num_stages=1, + ) + return + + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, + num_warps=NUM_WARPS, + num_stages=1, + ) + return diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py new file mode 100644 index 
0000000..40daec3 --- /dev/null +++ b/vllm/attention/ops/triton_decode_attention.py @@ -0,0 +1,674 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +# which was originally adapted from +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +# Changes: +# - Add support for page size >= 1. + +# Copyright 2025 vLLM Team +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for decoding. +It supports page size >= 1. +""" + +import logging + +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +is_hip_ = current_platform.is_rocm() + +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance +# and accuracy. +logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored.") + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < 
split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[None, :]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 64 if not is_hip_ else 8 + + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = q.shape[0], q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + num_warps = 4 + if kv_group_num != 1: + num_warps = 1 if is_hip_ else 2 + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * 
VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ + None, :] + q = tl.load(Q + offs_q, + mask=(mask_h[:, None]) & (mask_d[None, :]), + other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, + mask=(mask_h[:, None]) & (mask_dpe[None, :]), + other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & + (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), + qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + if is_hip_ and Lk >= 576: + BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL 
= 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + num_stages = 2 + if is_hip_: + # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + num_stages = 1 + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=4, + num_stages=num_stages, + Lk=Lk, + Lv=Lv, + **extra_kargs, + ) + + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + o, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, + mask=mask_d, + other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + b_seq_len, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + 
"kpack": 2 + } + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size=1, + logit_cap=0.0, +): + assert num_kv_splits == attn_logits.shape[2] + kv_group_num = q.shape[1] // v_buffer.shape[-2] + + if kv_group_num == 1: + # MHA + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + else: + # GQA/MQA/MLA + decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py new file mode 100644 index 0000000..745818e --- /dev/null +++ b/vllm/attention/ops/triton_flash_attention.py @@ -0,0 +1,821 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. 
+ +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. 
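+ # Illustrative numbers (not taken from any particular model): with
+ # BLOCK_N = 64 and actual_seqlen_k = 100, n_extra_tokens = 100 % 64 = 36,
+ # so only the final iteration takes the masking path below; columns
+ # 100..127 of that block get qk = -inf, which becomes exp2(-inf) = 0 in
+ # the softmax update and thus contributes nothing to the output.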
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. 
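+ # Illustrative numbers: with seqlen_q = 512, seqlen_k = 128 and
+ # BLOCK_M = BLOCK_N = 64, the bottom-right-aligned causal mask leaves the
+ # first 512 - 128 = 384 query rows with no visible keys. For start_m = 0,
+ # n_blocks_seqlen = cdiv(64 + 128 - 512, 64) is negative, so n_blocks <= 0
+ # and this work group returns below without doing any GEMMs.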
+ if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
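+ # (Presumably this refers to the base offset below containing only the
+ # off_h_q term and no off_z term, so different batch entries would write
+ # to overlapping regions of encoded_softmax when the batch size is > 1.)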
+ if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... 
+ PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. 
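+ # boundary_check=(0, 1) clips the store against the block pointer's
+ # (seqlen_q, ACTUAL_BLOCK_DMODEL) shape, so rows beyond seqlen_q and the
+ # padded tail of the head dimension are simply not written back.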
+ tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128, 256} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. 
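+ # (These constants only influence the dropout / encoded-softmax paths;
+ # this wrapper launches the kernel with dropout_p=0.0 and
+ # ENABLE_DROPOUT=False, so they are effectively inert here.)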
+ philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py new file mode 100644 index 0000000..9671b93 --- /dev/null +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import triton +import triton.language as tl + + +# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. + merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + + # FA2 and FA3 have different behavior for when the sum-exp is 0, this namely + # arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf. + # If we see an inf assume FA2 and convert inf to -inf for consistency + # and correctness. Inf generally doesn't make sense in this context outside + # of undefined-behavior/FA2-case, so I think this a safe assumption. 
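+ # Illustrative merge with made-up values: for p_lse = 1.0 and s_lse = 3.0,
+ # max_lse = 3.0, the prefix weight is exp(-2) / (exp(-2) + 1) ~= 0.12, the
+ # suffix weight ~= 0.88, and the merged LSE is 3.0 + log(1 + exp(-2))
+ # ~= 3.13 -- the same result a single pass over the full sequence would
+ # produce.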
+ p_lse = float('-inf') if p_lse == float('inf') else p_lse + s_lse = float('-inf') if s_lse == float('inf') else s_lse + + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # NOTE(woosuk): Be careful with the numerical stability. + # We should compute the scale first, and then multiply it with the output. + # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. + p_scale = tl.exp(p_lse) / out_se + s_scale = tl.exp(s_lse) / out_se + out = p_out * p_scale + s_out * s_scale + tl.store(output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py new file mode 100644 index 0000000..ebbdea2 --- /dev/null +++ b/vllm/attention/selector.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from contextlib import contextmanager +from functools import cache +from typing import Generator, Optional, Type + +import torch + +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.platforms import _Backend, current_platform +from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: + """ + Convert a string backend name to a _Backend enum value. + + Returns: + * _Backend: enum value if backend_name is a valid in-tree type + * None: otherwise it's an invalid in-tree type or an out-of-tree platform is + loaded. + """ + assert backend_name is not None + return _Backend[backend_name] if backend_name in _Backend.__members__ else \ + None + + +def get_env_variable_attn_backend() -> Optional[_Backend]: + ''' + Get the backend override specified by the vLLM attention + backend environment variable, if one is specified. + + Returns: + + * _Backend enum value if an override is specified + * None otherwise + ''' + backend_name = os.environ.get(STR_BACKEND_ENV_VAR) + return (None + if backend_name is None else backend_name_to_enum(backend_name)) + + +# Global state allows a particular choice of backend +# to be forced, overriding the logic which auto-selects +# a backend based on system & workload configuration +# (default behavior if this variable is None) +# +# THIS SELECTION TAKES PRECEDENCE OVER THE +# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE +forced_attn_backend: Optional[_Backend] = None + + +def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: + ''' + Force all attention operations to use a specified backend. 
+ + Passing `None` for the argument re-enables automatic + backend selection., + + Arguments: + + * attn_backend: backend selection (None to revert to auto) + ''' + global forced_attn_backend + forced_attn_backend = attn_backend + + +def get_global_forced_attn_backend() -> Optional[_Backend]: + ''' + Get the currently-forced choice of attention backend, + or None if auto-selection is currently enabled. + ''' + return forced_attn_backend + + +def get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_mla: bool = False, +) -> Type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. + return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + is_attention_free=is_attention_free, + is_blocksparse=is_blocksparse, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + ) + + +@cache +def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_v1: bool = False, + use_mla: bool = False, +) -> Type[AttentionBackend]: + if is_blocksparse: + logger.info("Using BlocksparseFlashAttention backend.") + from vllm.attention.backends.blocksparse_attn import ( + BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend + + # If there are no attention layers (e.g. we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) + return PlaceholderAttentionBackend + + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, + use_mla) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}") + return resolve_obj_by_qualname(attention_cls) + + +@contextmanager +def global_force_attn_backend_context_manager( + attn_backend: _Backend) -> Generator[None, None, None]: + ''' + Globally force a vLLM attention backend override within a + context manager, reverting the global attention backend + override to its prior state upon exiting the context + manager. 
+ + Arguments: + + * attn_backend: attention backend to force + + Returns: + + * Generator + ''' + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/vllm/beam_search.py b/vllm/beam_search.py new file mode 100644 index 0000000..5d4ebdb --- /dev/null +++ b/vllm/beam_search.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union + +from vllm.sequence import Logprob + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + + +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: list[int] + logprobs: list[dict[int, Logprob]] + cum_logprob: float = 0.0 + text: Optional[str] = None + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + multi_modal_data: Optional["MultiModalDataDict"] = None + mm_processor_kwargs: Optional[dict[str, Any]] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: list[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: list[int]): + self.beams: list[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens, logprobs=[]) + ] + self.completed: list[BeamSearchSequence] = [] + + +def get_beam_search_score( + tokens: list[int], + cumulative_logprob: float, + eos_token_id: int, + length_penalty: float = 1.0, +) -> float: + """Calculate the beam search score with length penalty. 
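+ Concretely: score = cumulative_logprob / seq_len ** length_penalty, where
+ seq_len excludes a trailing EOS token. Because cumulative_logprob is
+ negative, length_penalty > 1.0 favors longer sequences and values below
+ 1.0 favor shorter ones.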
+ + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + seq_len = len(tokens) + if tokens[-1] == eos_token_id: + seq_len -= 1 + + return cumulative_logprob / (seq_len**length_penalty) + + +def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, + length_penalty) + + return sort_beams_key diff --git a/vllm/benchmarks/endpoint_request_func.py b/vllm/benchmarks/endpoint_request_func.py new file mode 100644 index 0000000..32767a8 --- /dev/null +++ b/vllm/benchmarks/endpoint_request_func.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +"""The request function for API endpoints.""" + +import json +import os +import sys +import time +import traceback +from dataclasses import dataclass, field +from typing import Optional + +import aiohttp +from tqdm.asyncio import tqdm + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class RequestFuncInput: + """The input for the request function.""" + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + model_name: Optional[str] = None + best_of: int = 1 + logprobs: Optional[int] = None + extra_body: Optional[dict] = None + multi_modal_content: Optional[dict] = None + ignore_eos: bool = False + + +@dataclass +class RequestFuncOutput: + """The output of the request function including metrics.""" + generated_text: str = "" + success: bool = False + latency: float = 0.0 + output_tokens: int = 0 + ttft: float = 0.0 # Time to first token + itl: list[float] = field( + default_factory=list) # list of inter-token latencies + tpot: float = 0.0 # avg next-token latencies + prompt_len: int = 0 + error: str = "" + + +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """The async request function for the OpenAI Completions API. + + Args: + request_func_input: The input for the request function. + pbar: The progress bar to display the progress. + + Returns: + The output of the request function. + """ + api_url = request_func_input.api_url + assert api_url.endswith( + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
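+ # The Completions endpoint streams Server-Sent Events: each chunk is a
+ # "data: {...}" line and the stream ends with "data: [DONE]". The loop
+ # below strips the "data: " prefix, accumulates the generated text, and
+ # derives TTFT and inter-token latencies from the chunk timestamps.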
+ + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": request_func_input.best_of, + "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!") + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# TODO: Add more request functions for different API protocols. +ASYNC_REQUEST_FUNCS = { + "openai-comp": async_request_openai_completions, +} diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py new file mode 100644 index 0000000..813556f --- /dev/null +++ b/vllm/benchmarks/serve.py @@ -0,0 +1,925 @@ +# SPDX-License-Identifier: Apache-2.0 +r"""Benchmark online serving throughput. 
+ +On the server side, run one of the following commands +to launch the vLLM OpenAI API server: + vllm serve + +On the client side, run: + vllm bench serve \ + --endpoint-type \ + --label \ + --model \ + --dataset-name \ + --request-rate \ + --num-prompts +""" +import argparse +import asyncio +import gc +import json +import os +import random +import time +import warnings +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional + +import numpy as np +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS, + RequestFuncInput, + RequestFuncOutput) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.transformers_utils.tokenizer import get_tokenizer + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] + + +def sample_random_requests( + prefix_len: int, + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, +) -> list[tuple[str, int, int]]: + prefix_token_ids = np.random.randint(0, + tokenizer.vocab_size, + size=prefix_len).tolist() + + input_lens = np.random.randint( + int(input_len * range_ratio), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode(prefix_token_ids + + [(offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i])]) + + input_requests.append((prompt, int(prefix_len + input_lens[i]), + int(output_lens[i]), None)) + + return input_requests + + +async def get_request( + input_requests: list[tuple[str, int, int]], + request_rate: float, + burstiness: float = 1.0, +) -> AsyncGenerator[tuple[str, int, int], None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a tuple. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. 
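+ For example (illustrative numbers): with request_rate = 10 and
+ burstiness = 0.5, intervals are drawn from Gamma(shape=0.5, scale=0.2);
+ the mean interval is still 1 / request_rate = 0.1 s, but the draws are
+ more dispersed than in the exponential (Poisson) case, i.e. burstier.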
+ """ + input_requests = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: list[tuple[str, int, int]], + outputs: list[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], +) -> tuple[BenchmarkMetrics, list[int]]: + """Calculate the metrics for the benchmark. + + Args: + input_requests: The input requests. + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + tokenizer: The tokenizer to use. + selected_percentiles: The percentiles to select. + goodput_config_dict: The goodput configuration. + + Returns: + A tuple of the benchmark metrics and the actual output lengths. + """ + actual_output_lens: list[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_tokens + + if output_len is None: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) + actual_output_lens.append(output_len) + total_input += input_requests[i][1] + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by the endpoint + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + ) + + return metrics, actual_output_lens + + +async def benchmark( + endpoint_type: str, + api_url: str, + base_url: str, + model_id: str, + model_name: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: list[tuple[str, int, int]], + logprobs: Optional[int], + best_of: int, + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[str], + ignore_eos: bool, + goodput_config_dict: dict[str, float], + max_concurrency: Optional[int], + lora_modules: Optional[list[str]], +): + if endpoint_type in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + else: + raise ValueError(f"Unknown endpoint_type: {endpoint_type}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0]) + if endpoint_type != "openai-chat" and test_mm_content is not None: + # multi-modal benchmark is only available on OpenAI Chat endpoint. + raise ValueError("Multi-modal content is only supported on " + "'openai-chat' endpoint_type.") + test_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + ) + + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") + + if lora_modules: + # For each input request, choose a LoRA module at random. 
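+        # The picks are pre-sampled into an iterator so the request loop
+        # below only needs to call next(); the chosen adapter name then
+        # replaces both the request's model id and served model name.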
+ lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput(model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. + # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, burstiness): + prompt, prompt_len, output_len, mm_content = request + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=mm_content, + ignore_eos=ignore_eos) + tasks.append( + asyncio.create_task( + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + best_of=best_of, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request 
throughput (req/s):", + metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput:": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + + return result + + +def check_goodput_args(args): + # Check and parse goodput arguments + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. 
" + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any], + file_name: str) -> None: + metrics = [ + "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", + "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", + "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + ] + # These raw data might be useful, but they are rather big. They can be added + # later if needed + ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={k: [results[k]] + for k in metrics}, + extra_info={ + k: results[k] + for k in results if k not in metrics and k not in ignored_metrics + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--endpoint-type", + type=str, + default="openai-comp", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--label", + type=str, + default=None, + help="The label (prefix) of the benchmark results. If not specified, " + "the endpoint type will be used as the label.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="random", + choices=["random"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. 
This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument( + "--best-of", + type=int, + default=1, + help="Generates `best_of` sequences per prompt and " + "returns the best one.", + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--logprobs", + type=int, + default=None, + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--save-result", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--metadata", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " + "for metadata of this run to be saved in the result JSON file " + "for record keeping purposes.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{label}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" # noqa + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." 
+ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-seperated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-seperated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for random sampling.", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') + + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. 
For each request, the " + "script chooses a LoRA module at random.") + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + endpoint_type = args.endpoint_type + label = args.label + model_id = args.model + model_name = args.served_model_name + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code) + # TODO: This should be refactored to use the benchmark_dataset.py + # in later PRs. + if args.dataset_name is None: + raise ValueError( + "Please specify '--dataset-name' and the corresponding " + "'--dataset-path' if required.") + elif args.dataset_name == "random": + input_requests = sample_random_requests( + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + ) + + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + goodput_config_dict = check_goodput_args(args) + + # Avoid GC processing "static" data - reduce pause times. + gc.collect() + gc.freeze() + + benchmark_result = asyncio.run( + benchmark( + endpoint_type=endpoint_type, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + best_of=args.best_of, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + )) + + # Save config and results to json + if args.save_result: + result_json: dict[str, Any] = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["endpoint_type"] = endpoint_type + result_json["label"] = label + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["best_of"] = args.best_of + result_json["num_prompts"] = args.num_prompts + + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=") + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format." 
+ ) + + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + # Save to file + base_model_id = model_id.split("/")[-1] + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + label = label or endpoint_type + file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa + if args.result_filename: + file_name = args.result_filename + if args.result_dir: + file_name = os.path.join(args.result_dir, file_name) + with open(file_name, "w", encoding='utf-8') as outfile: + json.dump(result_json, outfile) + save_to_pytorch_benchmark_format(args, result_json, file_name) diff --git a/vllm/benchmarks/utils.py b/vllm/benchmarks/utils.py new file mode 100644 index 0000000..45a0ddb --- /dev/null +++ b/vllm/benchmarks/utils.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import json +import math +import os +from typing import Any + + +def convert_to_pytorch_benchmark_format(args: argparse.Namespace, + metrics: dict[str, list], + extra_info: dict[str, Any]) -> list: + """ + Save the benchmark results in the format used by PyTorch OSS benchmark with + on metric per record + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + records = [] + if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): + return records + + for name, benchmark_values in metrics.items(): + record = { + "benchmark": { + "name": "vLLM benchmark", + "extra_info": { + "args": vars(args), + }, + }, + "model": { + "name": args.model, + }, + "metric": { + "name": name, + "benchmark_values": benchmark_values, + "extra_info": extra_info, + }, + } + + tp = record["benchmark"]["extra_info"]["args"].get( + "tensor_parallel_size") + # Save tensor_parallel_size parameter if it's part of the metadata + if not tp and "tensor_parallel_size" in extra_info: + record["benchmark"]["extra_info"]["args"][ + "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + + records.append(record) + + return records + + +class InfEncoder(json.JSONEncoder): + + def clear_inf(self, o: Any): + if isinstance(o, dict): + return {k: self.clear_inf(v) for k, v in o.items()} + elif isinstance(o, list): + return [self.clear_inf(v) for v in o] + elif isinstance(o, float) and math.isinf(o): + return "inf" + return o + + def iterencode(self, o: Any, *args, **kwargs) -> Any: + return super().iterencode(self.clear_inf(o), *args, **kwargs) + + +def write_to_json(filename: str, records: list) -> None: + with open(filename, "w") as f: + json.dump(records, f, cls=InfEncoder) diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py new file mode 100644 index 0000000..45988c2 --- /dev/null +++ b/vllm/compilation/backends.py @@ -0,0 +1,711 @@ +# SPDX-License-Identifier: Apache-2.0 + +import ast +import dataclasses +import os +import pprint +import time +from contextlib import ExitStack +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from unittest.mock import patch + +import torch +import torch.fx as fx + +import vllm.envs as envs +from vllm.config import CompilationConfig, VllmConfig +from 
vllm.logger import init_logger +from vllm.utils import weak_ref_tensors + +from .compiler_interface import EagerAdaptor, InductorAdaptor +from .counter import compilation_counter +from .inductor_pass import InductorPass +from .monitor import end_monitoring_torch_compile +from .pass_manager import PostGradPassManager + +logger = init_logger(__name__) + + +class CompilerManager: + """ + A manager to manage the compilation process, including + caching the compiled graph, loading the compiled graph, + and compiling the graph. + + The cache is a dict mapping + `(runtime_shape, graph_index, backend_name)` + to `any_data` returned from the compiler. + + When serializing the cache, we save it to a Python file + for readability. We don't use json here because json doesn't + support int as key. + """ + + def __init__(self, use_inductor: bool): + self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() + cls = InductorAdaptor if use_inductor else EagerAdaptor + self.compiler = cls() + + def compute_hash(self, vllm_config: VllmConfig) -> str: + return self.compiler.compute_hash(vllm_config) + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.disable_cache = disable_cache + self.cache_dir = cache_dir + self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") + + if not disable_cache and os.path.exists(self.cache_file_path): + # load the cache from the file + with open(self.cache_file_path) as f: + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. + self.cache = ast.literal_eval(f.read()) + + self.compiler.initialize_cache(cache_dir=cache_dir, + disable_cache=disable_cache) + + def save_to_file(self): + if self.disable_cache: + return + with open(self.cache_file_path, "w") as f: + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + f.write(data) + + def load(self, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Optional[Callable]: + if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + return None + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, runtime_shape) + logger.debug( + "Directly load the %s-th graph for shape %s from %s via " + "handle %s", graph_index, str(runtime_shape), self.compiler.name, + handle) + return compiled_graph + + def compile(self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None) -> Any: + if graph_index == 0: + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + compilation_counter.num_backend_compilations += 1 + + compiled_graph = None + + # try to load from the cache + compiled_graph = self.load(graph, example_inputs, graph_index, + runtime_shape) + if compiled_graph is not None: + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Directly load the compiled graph for shape %s " + "from the cache", str(runtime_shape)) # noqa + return compiled_graph + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, runtime_shape) + + assert 
compiled_graph is not None, "Failed to compile the graph" + + # store the artifact in the cache + if handle is not None: + self.cache[(runtime_shape, graph_index, + self.compiler.name)] = handle + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, handle) + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info("Compiling a graph for general shape takes %.2f s", + elapsed) + else: + logger.info("Compiling a graph for shape %s takes %.2f s", + runtime_shape, elapsed) + + return compiled_graph + + +@dataclasses.dataclass +class SplitItem: + submod_name: str + graph_id: int + is_splitting_graph: bool + graph: fx.GraphModule + + +def split_graph(graph: fx.GraphModule, + ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]: + # split graph by ops + subgraph_id = 0 + node_to_subgraph_id = {} + split_op_graphs = [] + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + if node.op == 'call_function' and str(node.target) in ops: + subgraph_id += 1 + node_to_subgraph_id[node] = subgraph_id + split_op_graphs.append(subgraph_id) + subgraph_id += 1 + else: + node_to_subgraph_id[node] = subgraph_id + + # `keep_original_order` is important! + # otherwise pytorch might reorder the nodes and + # the semantics of the graph will change when we + # have mutations in the graph + split_gm = torch.fx.passes.split_module.split_module( + graph, + None, + lambda node: node_to_subgraph_id[node], + keep_original_order=True) + + outputs = [] + + names = [name for (name, module) in split_gm.named_modules()] + + for name in names: + if "." in name or name == "": + # recursive child module or the root module + continue + + module = getattr(split_gm, name) + + graph_id = int(name.replace("submod_", "")) + outputs.append( + SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) + + # sort by intetger graph_id, rather than string name + outputs.sort(key=lambda x: x.graph_id) + + return split_gm, outputs + + +# we share the global graph pool among all the backends +global_graph_pool = None + +compilation_start_time = 0.0 + + +class PiecewiseCompileInterpreter(torch.fx.Interpreter): + """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. + It runs the given graph with fake inputs, and compile some + submodules specified by `compile_submod_names` with the given + compilation configs. + + NOTE: the order in `compile_submod_names` matters, because + it will be used to determine the order of the compiled piecewise + graphs. The first graph will handle logging, and the last graph + has some special cudagraph output handling. 
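+
+    For example (hypothetical submodule names): if `split_gm` exposes
+    `submod_0` and `submod_2` as capturable pieces, they are compiled in
+    that order and each one is replaced in-place on the module by a
+    `PiecewiseBackend` wrapper that handles per-shape compilation and
+    cudagraph capture at runtime.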
+ """ + + def __init__(self, module: torch.fx.GraphModule, + compile_submod_names: List[str], vllm_config: VllmConfig, + graph_pool, vllm_backend: "VllmBackend"): + super().__init__(module) + from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() + self.compile_submod_names = compile_submod_names + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.vllm_config = vllm_config + self.vllm_backend = vllm_backend + + def run(self, *args): + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + with self.fake_mode: + return super().run(*fake_args) + + def call_module(self, target: torch.fx.node.Target, + args: Tuple[torch.fx.node.Argument, + ...], kwargs: Dict[str, Any]) -> Any: + assert isinstance(target, str) + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + index = self.compile_submod_names.index(target) + submod = self.fetch_attr(target) + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + global compilation_start_time + compiled_graph_for_general_shape = self.vllm_backend.\ + compiler_manager.compile( + submod, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None) + + self.module.__dict__[target] = PiecewiseBackend( + submod, self.vllm_config, self.graph_pool, index, + len(self.compile_submod_names), sym_shape_indices, + compiled_graph_for_general_shape, self.vllm_backend) + + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + +class VllmBackend: + """The compilation backend for `torch.compile` with vLLM. + It is used for compilation level of `CompilationLevel.PIECEWISE`, + where we customize the compilation. + + The major work of this backend is to split the graph into + piecewise graphs, and pass them to the piecewise backend. + + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. + """ + + vllm_config: VllmConfig + compilation_config: CompilationConfig + graph_pool: Any + _called: bool = False + # the graph we compiled + graph: fx.GraphModule + # the stiching graph module for all the piecewise graphs + split_gm: fx.GraphModule + piecewise_graphs: List[SplitItem] + returned_callable: Callable + # Inductor passes to run on the graph pre-defunctionalization + post_grad_passes: Sequence[Callable] + sym_tensor_indices: List[int] + input_buffers: List[torch.Tensor] + compiler_manager: CompilerManager + + def __init__( + self, + vllm_config: VllmConfig, + ): + global global_graph_pool + if global_graph_pool is None: + global_graph_pool = torch.cuda.graph_pool_handle() + + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. + # only investigate this when we use multiple streams + self.graph_pool = global_graph_pool + + # Passes to run on the graph post-grad. 
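+        # configure_post_pass() installs this manager as Inductor's
+        # `post_grad_custom_post_pass`, wrapping any user-provided pass
+        # that was already present in the Inductor config.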
+ self.post_grad_pass_manager = PostGradPassManager() + + self.sym_tensor_indices = [] + self.input_buffers = [] + + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + + self.compiler_manager: CompilerManager = CompilerManager( + self.compilation_config.use_inductor) + + # `torch.compile` is JIT compiled, so we don't need to + # do anything here + + def configure_post_pass(self): + config = self.compilation_config + self.post_grad_pass_manager.configure(config.pass_config) + + # Post-grad custom passes are run using the post_grad_custom_post_pass + # hook. If a pass for that hook exists, add it to the pass manager. + inductor_config = config.inductor_compile_config + PASS_KEY = "post_grad_custom_post_pass" + if PASS_KEY in inductor_config: + # Config should automatically wrap all inductor passes + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + inductor_config[PASS_KEY] = self.post_grad_pass_manager + + def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + + vllm_config = self.vllm_config + if not self.compilation_config.cache_dir: + # no provided cache dir, generate one based on the known factors + # that affects the compilation. if none of the factors change, + # the cache dir will be the same so that we can reuse the compiled + # graph. + + factors = [] + # 0. factors come from the env, for example, The values of + # VLLM_PP_LAYER_PARTITION will affects the computation graph. + env_hash = envs.compute_hash() + factors.append(env_hash) + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + config_hash = vllm_config.compute_hash() + factors.append(config_hash) + + # 2. factors come from the code files that are traced by Dynamo ( + # it mainly summarizes how the model is used in forward pass) + forward_code_files = list( + sorted(self.compilation_config.traced_files)) + self.compilation_config.traced_files.clear() + logger.debug( + "Traced files (to be considered for compilation cache):\n%s", + "\n".join(forward_code_files)) + hash_content = [] + for filepath in forward_code_files: + hash_content.append(filepath) + with open(filepath) as f: + hash_content.append(f.read()) + import hashlib + code_hash = hashlib.md5("\n".join(hash_content).encode(), + usedforsecurity=False).hexdigest() + factors.append(code_hash) + + # 3. 
compiler hash + compiler_hash = self.compiler_manager.compute_hash(vllm_config) + factors.append(compiler_hash) + + # combine all factors to generate the cache dir + hash_key = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, + "torch_compile_cache", + hash_key, + ) + self.compilation_config.cache_dir = cache_dir + + cache_dir = self.compilation_config.cache_dir + os.makedirs(cache_dir, exist_ok=True) + rank = vllm_config.parallel_config.rank + dp_rank = vllm_config.parallel_config.data_parallel_rank + local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") + os.makedirs(local_cache_dir, exist_ok=True) + self.compilation_config.local_cache_dir = local_cache_dir + + disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE + + if disable_cache: + logger.info("vLLM's torch.compile cache is disabled.") + else: + logger.info("Using cache directory: %s for vLLM's torch.compile", + local_cache_dir) + + self.compiler_manager.initialize_cache(local_cache_dir, disable_cache) + + # when dynamo calls the backend, it means the bytecode + # transform and analysis are done + compilation_counter.num_graphs_seen += 1 + from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time + logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + self.compilation_config.compilation_time += dynamo_time + + # we control the compilation process, each instance can only be + # called once + assert not self._called, "VllmBackend can only be called once" + + self.graph = graph + self.configure_post_pass() + + self.split_gm, self.piecewise_graphs = split_graph( + graph, self.compilation_config.splitting_ops) + + from torch._dynamo.utils import lazy_format_graph_code + + # depyf will hook lazy_format_graph_code and dump the graph + # for debugging, no need to print the graph here + lazy_format_graph_code("before split", self.graph) + lazy_format_graph_code("after split", self.split_gm) + + compilation_counter.num_piecewise_graphs_seen += len( + self.piecewise_graphs) + submod_names_to_compile = [ + item.submod_name for item in self.piecewise_graphs + if not item.is_splitting_graph + ] + + # propagate the split graph to the piecewise backend, + # compile submodules with symbolic shapes + PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, + self.vllm_config, self.graph_pool, + self).run(*example_inputs) + + graph_path = os.path.join(local_cache_dir, "computation_graph.py") + if not os.path.exists(graph_path): + # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa + # use `print_readable` because it can include submodules + src = "from __future__ import annotations\nimport torch\n" + \ + self.split_gm.print_readable(print_output=False) + src = src.replace("", "GraphModule") + with open(graph_path, "w") as f: + f.write(src) + + logger.debug("Computation graph saved to %s", graph_path) + + self._called = True + + if not self.compilation_config.use_cudagraph or \ + not self.compilation_config.cudagraph_copy_inputs: + return self.split_gm + + # if we need to copy input buffers for cudagraph + from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() + fake_args = [ + fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in example_inputs + ] + + # index of tensors that have symbolic shapes (batch size) + # for weights and static buffers, they 
will have concrete shapes. + # symbolic shape only happens for input tensors. + from torch.fx.experimental.symbolic_shapes import is_symbolic + self.sym_tensor_indices = [ + i for i, x in enumerate(fake_args) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ + any(is_symbolic(d) for d in x.size()) + ] + + # compiler managed cudagraph input buffers + # we assume the first run with symbolic shapes + # has the maximum size among all the tensors + self.input_buffers = [ + example_inputs[x].clone() for x in self.sym_tensor_indices + ] + + # this is the callable we return to Dynamo to run + def copy_and_call(*args): + list_args = list(args) + for i, index in enumerate(self.sym_tensor_indices): + runtime_tensor = list_args[index] + runtime_shape = runtime_tensor.shape[0] + static_tensor = self.input_buffers[i][:runtime_shape] + + # copy the tensor to the static buffer + static_tensor.copy_(runtime_tensor) + + # replace the tensor in the list_args to the static buffer + list_args[index] = static_tensor + return self.split_gm(*list_args) + + return copy_and_call + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[List[int]] = None + + +class PiecewiseBackend: + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: List[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. + + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. 
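+
+        Illustrative example (hypothetical sizes): with compile_sizes={8}
+        and cudagraph_capture_sizes={8, 16}, and after the very first run
+        (which always uses the general-shape graph), a call with batch
+        size 8 triggers a size-8 compilation, the configured number of
+        warmup runs (if any) execute that compiled graph eagerly, and a
+        cudagraph is then captured and replayed for later size-8 calls.
+        Batch size 16 skips the extra compilation and only goes through
+        warmup and capture, while any other size keeps using the
+        general-shape graph.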
+ """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.compile_sizes: Set[int] = set( + self.compilation_config.compile_sizes) + self.cudagraph_capture_sizes: Set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + if not entry.use_cudagraph: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. 
+ # We only log it in the debug mode. + logger.debug("Capturing a cudagraph for shape %s", + runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_caputured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py new file mode 100644 index 0000000..5a22cf7 --- /dev/null +++ b/vllm/compilation/compiler_interface.py @@ -0,0 +1,401 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import copy +import hashlib +import importlib.metadata +import os +from contextlib import ExitStack +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest.mock import patch + +import torch +import torch._inductor.compile_fx +import torch.fx as fx +from packaging.version import Version + +from vllm.config import VllmConfig + + +class CompilerInterface: + """ + The interface for a compiler that can be used by vLLM. + """ + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + """ + pass + + def compute_hash(self, vllm_config: VllmConfig) -> str: + """ + Gather all the relevant information from the vLLM config, + to compute a hash so that we can cache the compiled model. + + See :meth:`VllmConfig.compute_hash` to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. 
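+
+        A minimal sketch (hypothetical attribute, not part of this
+        interface): a concrete compiler could fold its own version into
+        the hash so that upgrading it invalidates cached artifacts, e.g.
+        ``hashlib.md5(self.version.encode()).hexdigest()[:10]``; compare
+        ``InductorAdaptor.compute_hash`` below, which hashes the system
+        and PyTorch state instead.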
+ """ + return "" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. + """ + return None, None + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + raise NotImplementedError("caching is not supported") + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. + """ + + def __init__(self) -> None: + self.guards: List[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + +class InductorAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler, version 2.5 and 2.6. 
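+
+    `initialize_cache` redirects the Inductor and Triton caches into
+    vLLM's cache directory so the directory can be copied to another
+    machine and reused, and `compile` returns a `(hash_str, file_path)`
+    pair as the cache handle that `load` later uses to look the compiled
+    graph up in the FX-graph cache.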
+ """ + name = "inductor" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.cache_dir = cache_dir + if disable_cache: + return + # redirect the cache directory to a sub-directory + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. + inductor_cache = os.path.join(cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + from torch._inductor import config + current_config = config.get_config_copy() + from torch._inductor.compile_fx import compile_fx + + # disable remote cache + current_config["fx_graph_cache"] = True + current_config["fx_graph_remote_cache"] = False + + if compiler_config is not None: + current_config.update(compiler_config) + + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + current_config["max_autotune"] = True + current_config["coordinate_descent_tuning"] = True + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. 
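+        # The patches below therefore capture two things during compilation:
+        # the FX-graph cache key (hash_str) and the path of the generated
+        # Python file (file_path); together they form the cache handle that
+        # is returned to CompilerManager alongside the compiled callable.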
+ + hash_str, file_path = None, None + from torch._inductor.codecache import (FxGraphCache, + compiled_fx_graph_hash) + + if torch.__version__.startswith("2.5"): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + if cell.cell_contents.__code__.co_filename.startswith( + self.cache_dir): + # this is the real file path compiled from Inductor + file_path = cell.cell_contents.__code__.co_filename + break + return inductor_compiled_graph + + hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa + elif torch.__version__ >= "2.6": + # function renamed in 2.6 + original_load_name = None + + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner( + *args, **kwargs) + nonlocal hash_str + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + code = cell.cell_contents.__code__ + if code.co_filename.startswith(self.cache_dir): + # this is the real file path + # compiled from Inductor + file_path = code.co_filename + break + hash_str = inductor_compiled_graph._fx_graph_cache_key + return output + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. + # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch("torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash)) + + # for providing a dummy shape environment + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env)) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache)) + + # Dynamo metrics context, see method for more details. 
+            stack.enter_context(self.metrics_context())
+
+            compiled_graph = compile_fx(
+                graph,
+                example_inputs,
+                inner_compile=hijacked_compile_fx_inner,
+                config_patches=current_config)
+
+        assert hash_str is not None, (
+            "failed to get the hash of the compiled graph")
+        assert file_path is not None, (
+            "failed to get the file path of the compiled graph")
+        return compiled_graph, (hash_str, file_path)
+
+    def load(self,
+             handle: Any,
+             graph: fx.GraphModule,
+             example_inputs: List[Any],
+             graph_index: int,
+             runtime_shape: Optional[int] = None) -> Callable:
+        assert isinstance(handle, tuple)
+        assert isinstance(handle[0], str)
+        assert isinstance(handle[1], str)
+        hash_str = handle[0]
+
+        from torch._inductor.codecache import FxGraphCache
+        with ExitStack() as exit_stack:
+            exit_stack.enter_context(
+                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
+                      lambda *args, **kwargs: AlwaysHitShapeEnv()))
+
+            # Dynamo metrics context, see method for more details.
+            exit_stack.enter_context(self.metrics_context())
+
+            if torch.__version__.startswith("2.5"):
+                inductor_compiled_graph = FxGraphCache._lookup_graph(
+                    hash_str, example_inputs, True, False)
+                assert inductor_compiled_graph is not None, (
+                    "Inductor cache lookup failed. Please remove "
+                    "the cache directory and try again."
+                )
+            elif torch.__version__ >= "2.6":
+                from torch._inductor.output_code import (
+                    CompiledFxGraphConstantsWithGm)
+                constants = CompiledFxGraphConstantsWithGm(graph)
+                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
+                    hash_str, example_inputs, True, None, constants)
+                assert inductor_compiled_graph is not None, (
+                    "Inductor cache lookup failed. Please remove "
+                    "the cache directory and try again."
+                )
+
+            # Inductor calling convention (function signature):
+            # f(list) -> tuple
+            # Dynamo calling convention (function signature):
+            # f(*args) -> Any
+
+            # need to know if the graph returns a tuple
+            from torch._inductor.compile_fx import graph_returns_tuple
+            returns_tuple = graph_returns_tuple(graph)
+
+            # this is the callable we return to Dynamo to run
+            def compiled_graph(*args):
+                # convert args to list
+                list_args = list(args)
+                graph_output = inductor_compiled_graph(list_args)
+                # unpack the tuple if needed
+                if returns_tuple:
+                    return graph_output
+                else:
+                    return graph_output[0]
+
+            return compiled_graph
+
+    def metrics_context(self) -> contextlib.AbstractContextManager:
+        """
+        This method returns the Dynamo metrics context (if it exists,
+        otherwise a null context). It is used by various compile components.
+        Present in torch>=2.6, it's used inside FxGraphCache in
+        torch==2.6 (but not after). It might also be used in various other
+        torch.compile internal functions.
+
+        Because it is re-entrant, we always set it (even if entering via Dynamo
+        and the context was already entered). We might want to revisit if it
+        should be set at a different level of compilation.
+
+        This is likely a bug in PyTorch: public APIs should not rely on
+        manually setting up internal contexts. But we also rely on non-public
+        APIs which might not provide these guarantees.
+ """ + if Version(importlib.metadata.version('torch')) >= Version("2.6"): + import torch._dynamo.utils + return torch._dynamo.utils.get_metrics_context() + else: + return contextlib.nullcontext() + + +class EagerAdaptor(CompilerInterface): + name = "eager" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + # we don't need to compile the graph, just return the graph itself. + # It does not support caching, return None for the handle. + return graph, None diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py new file mode 100644 index 0000000..5be4525 --- /dev/null +++ b/vllm/compilation/counter.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +import dataclasses +from contextlib import contextmanager + + +@dataclasses.dataclass +class CompilationCounter: + num_models_seen: int = 0 + num_graphs_seen: int = 0 + # including the splitting ops + num_piecewise_graphs_seen: int = 0 + # not including the splitting ops + num_piecewise_capturable_graphs_seen: int = 0 + num_backend_compilations: int = 0 + num_cudagraph_caputured: int = 0 + + def clone(self) -> "CompilationCounter": + return copy.deepcopy(self) + + @contextmanager + def expect(self, **kwargs): + old = self.clone() + yield + for k, v in kwargs.items(): + assert getattr(self, k) - getattr(old, k) == v, ( + f"{k} not as expected, before it is {getattr(old, k)}" + f", after it is {getattr(self, k)}, " + f"expected diff is {v}") + + +compilation_counter = CompilationCounter() diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py new file mode 100644 index 0000000..20afe69 --- /dev/null +++ b/vllm/compilation/decorators.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from typing import Callable, Dict, List, Optional, TypeVar, Union, overload +from unittest.mock import patch + +import torch +import torch.nn as nn +from torch._dynamo.symbolic_convert import InliningInstructionTranslator + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel, VllmConfig +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.utils import supports_dynamo + +from .monitor import start_monitoring_torch_compile + +logger = init_logger(__name__) + +_T = TypeVar("_T", bound=type[nn.Module]) + + +@overload +def support_torch_compile( + *, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]], +) -> Callable[[_T], _T]: + ... + + +@overload +def support_torch_compile(cls: _T) -> _T: + ... + + +def support_torch_compile( + cls: Optional[_T] = None, + *, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None, +) -> Union[Callable[[_T], _T], _T]: + """ + A decorator to add support for compiling the forward method of a class. + + Usage 1: use directly as a decorator without arguments: + + ```python + @support_torch_compile + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... + ``` + + Usage 2: use as a decorator with arguments: + + ```python + @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0}) + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... 
+ ``` + + `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic + dimensions of the argument. The dynamic dimensions can be either a single + integer or a list of integers. + + if `dynamic_arg_dims` is `None`, it is inferred from the type annotation + of the `forward` method, based on the following default rules: + + - if the argument is annotated as `torch.Tensor` or + `Optional[torch.Tensor]`, the first dimension will be + marked as dynamic. + - if the argument is annotated as `IntermediateTensors`, the first + dimension of all the tensors in the intermediate tensors + will be marked as dynamic. + + During runtime, when we actually mark dimensions of tensors, + it depends on the value of arguments: + + - if it is a single integer (can be negative), the corresponding dimension + of the argument will be marked as dynamic. + - if it is `None`, ignored. + - if it is `IntermediateTensors`, all the tensors in the intermediate + tensors will be marked as dynamic. + - otherwise, it will raise an error. + + NOTE: if an argument is `None`, it should always be passed as `None` during + the lifetime of the model, otherwise, it cannot be captured as a single + computation graph. + """ + + def cls_decorator_helper(cls: _T) -> _T: + # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` + # to avoid too much indentation for `_support_torch_compile`` + if not hasattr(cls, 'forward'): + raise TypeError("decorated class should have a forward method.") + sig = inspect.signature(cls.forward) + inferred_dynamic_arg_dims = dynamic_arg_dims + if inferred_dynamic_arg_dims is None: + inferred_dynamic_arg_dims = {} + for k, v in sig.parameters.items(): + if v.annotation in [ + torch.Tensor, Optional[torch.Tensor], + IntermediateTensors, Optional[IntermediateTensors] + ]: + inferred_dynamic_arg_dims[k] = 0 + + logger.debug(("Inferred dynamic dimensions for " + "forward method of %s: %s"), cls, + list(inferred_dynamic_arg_dims.keys())) + + if len(inferred_dynamic_arg_dims) == 0: + raise ValueError( + "No dynamic dimensions found in the forward method of " + f"{cls}. Please provide dynamic_arg_dims explicitly.") + + for k in inferred_dynamic_arg_dims: + if k not in sig.parameters: + raise ValueError( + f"Argument {k} not found in the forward method of {cls}") + return _support_torch_compile(cls, inferred_dynamic_arg_dims) + + if cls is not None: + # use `support_torch_compile` as a decorator without arguments + assert isinstance(cls, type) + return cls_decorator_helper(cls) + + return cls_decorator_helper + + +def _support_torch_compile( + cls: _T, + dynamic_arg_dims: Dict[str, Union[int, List[int]]], +) -> _T: + """ + A decorator to add support for compiling the forward method of a class. + """ + if TorchCompileWrapperWithCustomDispatcher in cls.__bases__: + # support decorating multiple times + return cls + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): + old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config + # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # will handle the compilation, so we don't need to do anything here. 
+ self.do_not_compile = \ + vllm_config.compilation_config.level in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS + ] or not supports_dynamo() + if self.do_not_compile: + return + compilation_counter.num_models_seen += 1 + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) + + cls.__init__ = __init__ + + def __call__(self, *args, **kwargs): + # torch.compiler.is_compiling() means we are inside the compilation + # e.g. TPU has the compilation logic in model runner, so we don't + # need to compile the model inside. + if self.do_not_compile or torch.compiler.is_compiling(): + return self.forward(*args, **kwargs) + + # the first compilation needs to have dynamic shapes marked + if len(self.compiled_codes) < 1: + sig = inspect.signature(self.__class__.forward) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [ + arg.ndim + dim if dim < 0 else dim for dim in dims + ] + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [ + tensor.ndim + dim if dim < 0 else dim + for dim in dims + ] + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}.") + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config) + logger.debug("Start compiling function %s", + self.original_code_object) + + # if we don't use custom dispatcher, we can directly call the + # compiled function and let torch.compile handle the dispatching, + # with the overhead of guard evaluation and recompilation. + if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + # it seems Dynamo reuse the compilation across instances, + # while we need to make sure the compiled code is not reused. + # we need to control all the compilation of the model. + torch._dynamo.eval_frame.remove_from_cache( + self.original_code_object) + + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. + + # 1. the file containing the top-level forward function + self.vllm_config.compilation_config.traced_files.add( + self.original_code_object.co_filename) + + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files + inline_call = InliningInstructionTranslator.inline_call + + def patched_inline_call(parent, func, args, kwargs): + code = func.get_code() + self.vllm_config.compilation_config.traced_files.add( + code.co_filename) + return inline_call(parent, func, args, kwargs) + + with patch.object(InliningInstructionTranslator, 'inline_call', + patched_inline_call): + output = self.compiled_callable(*args, **kwargs) + return output + + # usually, capturing the model once is enough, and then we can + # dispatch to the compiled code directly, without going through + # the Dynamo guard mechanism. 
+ with self.dispatch_to_code(0): + model_output = self.forward(*args, **kwargs) + return model_output + + cls.__call__ = __call__ + return cls diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py new file mode 100644 index 0000000..9b0e9c5 --- /dev/null +++ b/vllm/compilation/fix_functionalization.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 + +import operator +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class FixFunctionalizationPass(VllmInductorPass): + """ + This pass defunctionalizes certain nodes to avoid redundant tensor copies. + After this pass, DCE (dead-code elimination) should never be run, + as de-functionalized nodes may appear as dead code. + + To add new nodes to defunctionalize, add to the if-elif chain in __call__. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_fix_functionalization") + + self.nodes_to_remove: List[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # rotary_embedding is a special case: the two mutating inputs + # are query and key, which are slices of mm_node. + # While functionalized, results at[1] and at[2] are scattered + # back into mm_node. After de-functionalization, we can just + # use mm_node directly. + for idx, user in self.getitem_users(node).items(): + for user_of_getitem in user.users: + if is_func(user_of_getitem, + torch.ops.aten.slice_scatter.default): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + # rms_norm replacements avoid the most copies for LLaMa. 
+ elif at_target == torch.ops._C.fused_add_rms_norm.default: + mutated_args = {1: 'input', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target in [ + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default + ]: + mutated_args = {1: 'result'} + self.defunctionalize(graph, node, mutated_args) + + elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: 'out'} + # Because we have an 'out', need to specify args directly + self.defunctionalize(graph, + node, + mutated_args, + args=('out', 'input')) + else: + continue # skip the count + + count += 1 + + self.dump_graph(graph, "before_fix_functionalization_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug("De-functionalized %s nodes, removed %s nodes", count, + count_removed) + self.dump_graph(graph, "after_fix_functionalization") + self.end_and_log() + + def _remove(self, node_or_nodes: Union[torch.fx.Node, + Iterable[torch.fx.Node]]): + """ + Stage a node (or nodes) for removal at the end of the pass. + """ + if isinstance(node_or_nodes, torch.fx.Node): + self.nodes_to_remove.append(node_or_nodes) + else: + self.nodes_to_remove.extend(node_or_nodes) + + def defunctionalize(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: Dict[int, Union[torch.fx.Node, str]], + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + De-functionalize a node by replacing it with a call to the original. + It also replaces the getitem users with the mutated arguments. + See replace_users_with_mutated_args and insert_defunctionalized. + """ + self.replace_users_with_mutated_args(node, mutated_args) + self.insert_defunctionalized(graph, node, args=args) + self._remove(node) + + def replace_users_with_mutated_args(self, node: torch.fx.Node, + mutated_args: Dict[int, + Union[torch.fx.Node, + str]]): + """ + Replace all getitem users of the auto-functionalized node with the + mutated arguments. + :param node: The auto-functionalized node + :param mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. + """ + for idx, user in self.getitem_users(node).items(): + arg = mutated_args[idx] + arg = node.kwargs[arg] if isinstance(arg, str) else arg + user.replace_all_uses_with(arg) + self._remove(user) + + def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]: + """ + Returns the operator.getitem users of the auto-functionalized node, + indexed by the index they are getting. + """ + users = {} + for user in node.users: + if is_func(user, operator.getitem): + idx = user.args[1] + users[idx] = user + return users + + def insert_defunctionalized(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + Insert a new defunctionalized node into the graph before node. + If one of the kwargs is 'out', provide args directly, + as node.kwargs cannot be used. 
+ See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 + + :param graph: Graph to insert the defunctionalized node into + :param node: The auto-functionalized node to defunctionalize + :param args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. + """ # noqa: E501 + assert is_func(node, auto_functionalized), \ + f"node must be auto-functionalized, is {node} instead" + + # Create a new call to the original function + with graph.inserting_before(node): + function = node.args[0] + if args is None: + graph.call_function(function, kwargs=node.kwargs) + else: + # Args passed as strings refer to items in node.kwargs + args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg + for arg in args) + graph.call_function(function, args=args) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py new file mode 100644 index 0000000..b46f5f5 --- /dev/null +++ b/vllm/compilation/fusion.py @@ -0,0 +1,617 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload + +from vllm.config import CompilationConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fx_utils import find_getitem_maybe +from .multi_output_match import MultiOutputMatch +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + + +class QuantKey(NamedTuple): + """ + Named tuple for identifying the type of quantization. + dtype: quantized data type + static: static quantization if True, dynamic if False + per_tensor: per-tensor quantization if True, per-token if False + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + static: bool + per_tensor: bool = True + symmetric: bool = True + + def __str__(self): + return (f"QuantKey({'static' if self.static else 'dynamic'}," + f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'per_tensor' if self.per_tensor else 'per_token'}," + f"{'a' if not self.symmetric else ''}symmetric)") + + +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True) + +QUANT_OPS: Dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa +} + + +class FusedRMSQuantKey(NamedTuple): + """ + Named tuple for identifying the type of RMSNorm + quant fusion. 
+ quant: type of quantization + fused_add: does the op also perform the residual add + """ + quant: QuantKey + fused_add: bool + + def __str__(self): + return (f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)") + + +FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = { + FusedRMSQuantKey(kFp8StaticTensorSym, False): + torch.ops._C.rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8StaticTensorSym, True): + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, False): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, True): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa +} + + +class QuantMultiOutputMatch(MultiOutputMatch): + + def __init__(self, match: pm.Match, quant_op, fused_op): + super().__init__(match) + assert isinstance(quant_op, OpOverload) + assert isinstance(fused_op, OpOverload) + self.QUANT_OP = quant_op # in-place quant op + self.FUSED_OP = fused_op # in-place fused quant op + + def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node, + int]], + **kwargs): + """ + This utility function inserts an auto-functionalized node for FUSED_OP. + It also correctly sets its meta value and rebinds the users of the + unfused nodes to use the fused node instead. + + :param fused_return_mapping: A dictionary, mapping from getitem indices + of the fused node result to a tuple of the old node and a getitem index. + :param kwargs: kwargs that get directly forwarded to the auto_fn node + + Example: + If we want to replace this graph: + _, x1, x2 = auto_fn(op1) + _, y1, y2 = auto_fn(op2) + + with + _, x1, y2, x2 = auto_fn(FUSED_OP) + + we would call: + insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} + + Note that the 0th element is None for auto-functionalized in-place ops. + Hence, others appear 1-indexed. + """ + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) + indices = fused_return_mapping.keys() + getitem_nodes = self.insert_getitems(fused_node, indices) + + # Prepare the meta value, use a list so it's mutable + meta_val = [None] * (max(indices) + 1) + + # Iterate through elements of the tuple produced by fused_node + for idx, getitem_node in zip(indices, getitem_nodes): + old_node, old_idx = fused_return_mapping[idx] + + # If the old value was never used, the old_getitem might not exist + old_getitem = find_getitem_maybe(old_node, old_idx) + if old_getitem is not None: + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. 
+ old_getitem.replace_all_uses_with(getitem_node) + getitem_node.meta["val"] = old_getitem.meta["val"] + + # Extract the appropriate meta value + # It is present even if the getitem node does not exist + meta_val[idx] = old_node.meta["val"][old_idx] + + # Fix the meta value on the new fused node + fused_node.meta["val"] = tuple(meta_val) + + +class RMSNormQuantPattern: + + def __init__(self, epsilon: float, key: FusedRMSQuantKey): + self.epsilon = epsilon + self.quant_dtype = key.quant.dtype + + assert key.quant in QUANT_OPS, \ + f"unsupported quantization scheme {key.quant}" + self.QUANT_OP = QUANT_OPS[key.quant] + + assert key in FUSED_OPS, \ + f"unsupported fused rmsnorm+quant op for {key}" + self.FUSED_OP = FUSED_OPS[key] + + +class RMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + fused_key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, fused_key) + + def register(self, pm_pass: PatternMatcherPass): + # Cannot use methods, as the self argument affects tracing + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale) + + # result + return at2[1] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result + return at[1] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, + pm_pass) + + +class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale) + + # result, residual + return at1[1], at[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, residual + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + 
pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and residual. + # The auto_fn node returns a tuple of (None, result, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + + # 0 is always None + fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} + self.insert_fused_node(fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + **kwargs) + + +class RMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool, + symmetric=True): + key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale, + scale_ub=None) + + # result, scale + return at2[1], at2[2] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None) + + # result, scale + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 1 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and scale. + # The auto_fn node returns a tuple of (None, result, scale). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) 
# noqa + # result_node_new = at[1] + # scale_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + del kwargs["result_rms"] # not used in the fused op + + fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + residual=None, # not used but required + **kwargs) + + +class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool = True, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale, + scale_ub=None) + + # result, residual, scale + return at1[1], at[2], at1[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual) + + # result, residual, scale + return at[1], at[3], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract result, scale, and residual. + # The auto_fn node returns a tuple (None, result, scale, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa + # result_node_new = at[1] + # scale_node_new = at[2] + # residual_node_new = at[3] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + + fused_return_mapping = { + 1: (quant_node, 1), # result + 2: (quant_node, 2), # scale + 3: (rms_node, 2), # residual + } + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + **kwargs) + + +class FusionPass(VllmInductorPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + It also manually processes multi-output matches, as those are broken in + the torch pattern matcher. 
+
+    Because patterns can only be registered once, the pass is a singleton.
+    This will be addressed in a future version of PyTorch:
+    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
+    """
+
+    _instance: 'Optional[FusionPass]' = None
+
+    @classmethod
+    def instance(cls, config: CompilationConfig.PassConfig):
+        """
+        Get the singleton instance of the FusionPass.
+        If the instance exists, the config is updated but
+        initialization is not repeated.
+        """
+        if cls._instance is None:
+            cls._instance = FusionPass(config)
+        else:
+            cls._instance.config = config
+        return cls._instance
+
+    def __init__(self, config: CompilationConfig.PassConfig):
+        assert self.__class__._instance is None, \
+            "FusionPass singleton instance already exists"
+        super().__init__(config)
+
+        self.matches: List[MultiOutputMatch] = []
+        self.patterns: PatternMatcherPass = PatternMatcherPass(
+            pass_name="fusion_pass")
+
+        for epsilon in [1e-5, 1e-6]:
+            # Fuse rms_norm + static fp8 quant
+            RMSNormStaticQuantPattern(epsilon,
+                                      FP8_DTYPE).register(self.patterns)
+
+            # Matches for patterns below have 2 or more outputs,
+            # so we need to process them manually (see process_matches)
+
+            # Fuse fused_add_rms_norm + static fp8 quant
+            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
+                self.patterns, self.record_match)
+
+            # Fuse rms_norm + dynamic per-token fp8 quant
+            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
+                                       per_tensor=False).register(
+                                           self.patterns, self.record_match)
+
+            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
+            FusedAddRMSNormDynamicQuantPattern(epsilon,
+                                               FP8_DTYPE,
+                                               per_tensor=False).register(
+                                                   self.patterns,
+                                                   self.record_match)
+
+        # WARNING: This is a hack to clear the pattern matcher cache
+        # and allow multiple values of epsilon.
+        torch._inductor.pattern_matcher._seen_patterns.clear()
+
+    def record_match(self, match: MultiOutputMatch) -> bool:
+        # Hijack the extra_check to record the match and
+        # save it for post-processing.
+        self.matches.append(match)
+
+        # Return False to prevent automatic replacement.
+        return False
+
+    def process_matches(self, graph: fx.Graph):
+        """
+        Manually process multi-output matches and replace them with fused nodes.
+        See MultiOutputMatch for more details.
+ """ + for match in self.matches: + match.process() + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in self.matches + for node in match.match.nodes) + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_fusion") + + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_pattern_match") + + # Manually process multi-output matches (and run DCE) + self.process_matches(graph) + logger.debug("Post-processed %s matches", len(self.matches)) + self.dump_graph(graph, "after_fusion") + self.matches.clear() + self.end_and_log() diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py new file mode 100644 index 0000000..b9a8d31 --- /dev/null +++ b/vllm/compilation/fx_utils.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 + +import operator +from typing import Iterable, Optional + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._ops import OpOverload + + +def is_func(node: fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: fx.Node, idx: int) -> fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py new file mode 100644 index 0000000..08dd8c8 --- /dev/null +++ b/vllm/compilation/inductor_pass.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import importlib.metadata +import inspect +import json +import types +from typing import Any, Callable, Dict, Optional, Union + +import torch +from packaging.version import Version +from torch import fx + +if Version(importlib.metadata.version('torch')) >= Version("2.6"): + from torch._inductor.custom_graph_pass import CustomGraphPass +else: + # CustomGraphPass is not present in 2.5 or lower, import our version + from .torch25_custom_graph_pass import ( # noqa: yapf + Torch25CustomGraphPass as CustomGraphPass) + + +class InductorPass(CustomGraphPass): + """ + A custom graph pass that uses a hash of its source as the UUID. + This is defined as a convenience and should work in most cases. + """ + + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. 
+ """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.hexdigest() + + @staticmethod + def hash_dict(dict_: Dict[Any, Any]): + """ + Utility method to hash a dictionary, can alternatively be used for uuid. + :return: A sha256 hash of the json rep of the dictionary. + """ + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. + """ + + def __init__(self, + callable: Callable[[fx.Graph], None], + uuid: Optional[Any] = None): + self.callable = callable + self._uuid = self.hash_source(callable) if uuid is None else uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py new file mode 100644 index 0000000..786c7c1 --- /dev/null +++ b/vllm/compilation/monitor.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import time + +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + +context_manager = None +torch_compile_start_time: float = 0.0 + + +def start_monitoring_torch_compile(vllm_config: VllmConfig): + global torch_compile_start_time + torch_compile_start_time = time.time() + + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE and \ + compilation_config.debug_dump_path: + import depyf + path = os.path.join(compilation_config.debug_dump_path, + f"rank_{vllm_config.parallel_config.rank}") + global context_manager + context_manager = depyf.prepare_debug(path) + context_manager.__enter__() + + +def end_monitoring_torch_compile(vllm_config: VllmConfig): + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE: + logger.info("torch.compile takes %.2f s in total", + compilation_config.compilation_time) + global context_manager + if context_manager is not None: + context_manager.__exit__(None, None, None) + context_manager = None diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py new file mode 100644 index 0000000..e6f6a60 --- /dev/null +++ b/vllm/compilation/multi_output_match.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 + +import abc +import operator +from abc import abstractmethod +from typing import Iterable, List, Tuple + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor import pattern_matcher as pm +from torch._ops import OpOverload +from torch.fx import Node + +from vllm.compilation.fx_utils import find_auto_fn + + +class MultiOutputMatch(abc.ABC): + """ + This class provides utilities to process multi-output matches and + manually insert replacements. 
+ + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ + + def __init__(self, match: pm.Match): + self.match = match + + @abstractmethod + def process(self): + """ + Process a multi-output match and manually insert the replacement. + + This method should: + 1. Insert the replacement nodes after the last node in the match. + 2. Rebind the users of nodes in the match to use the new nodes. + 3. Set meta["val"] for de-functionalization. + + The result of an auto-functionalized node is a tuple of tensors. + The first element is the return value of the function, usually None. + The remaining elements are the mutated args of the function. + + All auto-functionalized nodes must contain a proper meta["val"], + as it is used by de-functionalization. meta["val"] has to contain the + value of the node (tuple of tensors) that would be returned by the + functionalized node during tracing. + + Existing nodes in the graph all have this property set, but we have + to set it manually for new nodes we insert. + + Example: + # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None + at = auto_functionalized(torch.ops._C.foo.default, a, b, c) + # at.meta["val"] = (None, a, c) + """ + raise NotImplementedError + + @property + def nodes(self) -> List[fx.Node]: + return self.match.nodes + + @property + def graph(self) -> fx.Graph: + return self.match.graph + + def find_auto_fn(self, op) -> fx.Node: + """ + Find the first auto_functionalized node with the given op in the match. + """ + return find_auto_fn(self.nodes, op) + + def inserting_after_match(self): + """ + Insert nodes after the last node in the match. + This is done to avoid use-before-definition errors after inserting + replacement nodes. + """ + + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(self.graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return self.graph.inserting_after(last_node_in_match) + + def insert_getitems(self, tuple_node: fx.Node, + indices: Iterable[int]) -> Tuple[fx.Node, ...]: + """ + Insert operator.getitem nodes to extract elements from a tuple node. + + :param tuple_node: The tuple node to extract elements from. + :param indices: The indices of the elements to extract. + :return: Tuple of the new getitem nodes, corresponding to the indices. + """ + with self.graph.inserting_after(tuple_node): + return tuple( + self.graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices) + + def insert_auto_fn(self, op: OpOverload, kwargs) -> Node: + """ + Insert an auto_functionalized node with the given op and kwargs. + """ + return self.graph.call_function(auto_functionalized, (op, ), + kwargs=kwargs) diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py new file mode 100644 index 0000000..19127e9 --- /dev/null +++ b/vllm/compilation/noop_elimination.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Union + +import torch.fx +from torch import SymInt + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class NoOpEliminationPass(VllmInductorPass): + """ + This is an inductor pass that removes redundant reshape/slice operations. + It is required for RMSNorm-quant fusion to work properly. 
+ That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. Additionally, torch internal no-op elimination pass does + not handle certain slice variants. + + Example graph 1: + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Example graph 2: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0) + at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...) + out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0) + + Can be replaced with: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) + out: "f16[s0, 4096]" = at[1] + + TODO(luka): This is currently tested in test_fusion, + but separate tests could be good. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_noop_elimination") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue + + if self.all_dims_equivalent(shape, input_shape): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice.Tensor): + input, dim_index, start, end = node.args[:4] + input_shape = input.meta["val"].shape + i_dim = input_shape[dim_index] + + if start == 0 and self.dims_equivalent(end, i_dim): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice_scatter.default): + base, view, dim_index, start, end = node.args[:5] + base_shape = base.meta["val"].shape + view_shape = view.meta["val"].shape + + view_dim = view_shape[dim_index] + + # Check that view fully covers base and the full view is used + # (if the view fully covered the base after slicing but was not + # fully used, we could replace slice_scatter with a simple slice + # but that's a niche case). + if (base_shape == view_shape and start == 0 + and self.dims_equivalent(end, view_dim)): + node.replace_all_uses_with(view) + graph.erase_node(node) + count += 1 + + logger.debug("Removed %s no-op reshapes and slices", count) + self.dump_graph(graph, "after_noop_elimination") + self.end_and_log() + + def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], + i_dims: Iterable[Union[int, SymInt]]): + return all( + self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape/slice + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. 
The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. + + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py new file mode 100644 index 0000000..530a88b --- /dev/null +++ b/vllm/compilation/pass_manager.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +from torch import fx as fx + +from vllm.config import CompilationConfig +from vllm.logger import init_logger + +from .fix_functionalization import FixFunctionalizationPass +from .fusion import FusionPass +from .inductor_pass import CustomGraphPass, InductorPass +from .noop_elimination import NoOpEliminationPass + +logger = init_logger(__name__) + + +class PostGradPassManager(CustomGraphPass): + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It supports uuid for the Inductor code cache. That includes torch<2.6 + support using pickling (in .inductor_pass.CustomGraphPass). + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (NoopEliminationPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: List[InductorPass] = [] + + def __call__(self, graph: fx.Graph): + for pass_ in self.passes: + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure(self, pass_config: CompilationConfig.PassConfig): + self.pass_config = pass_config + if pass_config.enable_noop: + self.passes += [NoOpEliminationPass(pass_config)] + + if pass_config.enable_fusion: + self.passes += [FusionPass.instance(pass_config)] + + self.fix_functionalization = FixFunctionalizationPass(pass_config) + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def uuid(self): + """ + The PostGradPassManager is set as a custom pass in the Inductor and + affects compilation caching. Its uuid depends on the UUIDs of all + dependent passes and the pass config. See InductorPass for more info. + """ + state = {"pass_config": self.pass_config.uuid(), "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return InductorPass.hash_dict(state) diff --git a/vllm/compilation/torch25_custom_graph_pass.py b/vllm/compilation/torch25_custom_graph_pass.py new file mode 100644 index 0000000..4b881d0 --- /dev/null +++ b/vllm/compilation/torch25_custom_graph_pass.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch + + +class Torch25CustomGraphPass(ABC): # noqa (redefinition) + """ + This class replaces CustomGraphPass from torch==2.6 when using torch<2.6. + It conforms to the 2.6 interface but also supports pickling, as that's what + the inductor code cache uses to determine the cache key before 2.6. + (in 2.6 and above, uuid() is used.) + + Subclasses can just "pretend" that uuid is used. 
+ """ + + @abstractmethod + def __call__(self, graph: torch.fx.graph.Graph) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. + Return None to skip inductor code caching entirely. + """ + + def __getstate__(self): + """ + Pickling is used instead of uuid() in torch<2.6. Just return uuid() + to enable subclasses to only have to implement uuid. + """ + return self.uuid() + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CustomGraphPass because pickling" + " is used for cache key uuid. Use torch>=2.6 with" + " native uuid support for custom passes.") diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py new file mode 100644 index 0000000..98ed6f1 --- /dev/null +++ b/vllm/compilation/vllm_inductor_pass.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import torch + +from vllm.config import CompilationConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +from .inductor_pass import InductorPass + +logger = init_logger(__name__) + + +class VllmInductorPass(InductorPass): + """ + An inductor pass with access to vLLM PassConfig. + It provides timing, logging, and dumping utilities. + """ + + def __init__(self, config: CompilationConfig.PassConfig): + self.config = config + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + if stage in self.config.dump_graph_stages or always: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("%s printing graph to %s", self.pass_name, filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) + + +class PrinterInductorPass(VllmInductorPass): + + def __init__(self, + name: str, + config: CompilationConfig.PassConfig, + always=False): + super().__init__(config) + self.name = name + self.always = always + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, self.name, always=self.always) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py new file mode 100644 index 0000000..a8a283d --- /dev/null +++ b/vllm/compilation/wrapper.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Callable, List, Optional + +import torch + +import vllm.envs as envs +from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class 
TorchCompileWrapperWithCustomDispatcher: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. + """ + + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): + + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + backend = vllm_config.compilation_config.init_backend(vllm_config) + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: List[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + + # read the env var to determine whether to use the custom dispatcher + # subclasses can use this to switch between the custom dispatcher + # and the default Dynamo guard mechanism. + self.use_custom_dispatcher: bool = \ + compilation_level >= CompilationLevel.DYNAMO_ONCE + + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... + + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 + frame = sys._getframe() + while frame and frame.f_back: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + local_cache_dir = self.vllm_config.compilation_config.local_cache_dir + if isinstance(local_cache_dir, str): + decompiled_file = os.path.join(local_cache_dir, + "transformed_code.py") + if not os.path.exists(decompiled_file): + try: + # usually the decompilation will succeed for most models, + # as we guarantee a full-graph compilation in Dynamo. + # but there's no 100% guarantee, since decompliation is + # not a reversible process. + import depyf + src = depyf.decompile(new_code) + with open(decompiled_file, "w") as f: + f.write(src) + + logger.debug("Dynamo transformed code saved to %s", + decompiled_file) + except Exception: + pass + + if self.vllm_config.compilation_config.use_cudagraph and \ + "update" in new_code.co_names: + import depyf + src = depyf.decompile(new_code) + msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. 
Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa + raise RuntimeError(msg) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code. + Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/config.py b/vllm/config.py new file mode 100644 index 0000000..bab44a2 --- /dev/null +++ b/vllm/config.py @@ -0,0 +1,3862 @@ +# SPDX-License-Identifier: Apache-2.0 + +import ast +import copy +import enum +import hashlib +import importlib.metadata +import json +import sys +import warnings +from collections import Counter +from collections.abc import Mapping +from contextlib import contextmanager +from dataclasses import dataclass, field, replace +from importlib.util import find_spec +from pathlib import Path +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, + Optional, Protocol, Union) + +import torch +from packaging.version import Version +from pydantic import BaseModel, Field, PrivateAttr +from torch.distributed import ProcessGroup, ReduceOp +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + get_quantization_config) +from vllm.model_executor.models import ModelRegistry +from vllm.platforms import CpuArchEnum, current_platform +from vllm.sampling_params import GuidedDecodingParams +from vllm.tracing import is_otel_available, otel_import_error_traceback +from vllm.transformers_utils.config import ( + ConfigFormat, get_config, get_hf_image_processor_config, + get_hf_text_config, get_pooling_config, + get_sentence_transformer_tokenizer_config, is_encoder_decoder, + try_get_generation_config, uses_mrope) +from vllm.transformers_utils.s3_utils import S3Model +from vllm.transformers_utils.utils import is_s3, maybe_model_redirect +from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, + get_cpu_memory, get_open_port, random_uuid, + resolve_obj_by_qualname) + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + + from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + from vllm.model_executor.model_loader.loader import BaseModelLoader + from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +else: + QuantizationConfig = None + +logger = init_logger(__name__) + +# This value is chosen to have a balance between ITL and TTFT. Note it is +# not optimized for throughput. 
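+# (Illustrative example, assuming chunked prefill is enabled: with a budget of
+# 2048 tokens per scheduler step, an 8192-token prompt is prefilled over at
+# least four steps, trading a longer TTFT for steadier ITL of requests that
+# are already decoding.)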
+_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 +_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 + +TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", + "score", "reward", "transcription"] + +_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", + "draft", "transcription"] + +RunnerType = Literal["generate", "pooling", "draft", "transcription"] + +_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { + "generate": ["generate"], + "pooling": ["embed", "classify", "score", "reward"], + "draft": ["draft"], + "transcription": ["transcription"], +} + +_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = { + task: runner + for runner, tasks in _RUNNER_TASKS.items() + for task in tasks +} + +HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], + PretrainedConfig]] + + +class SupportsHash(Protocol): + + def compute_hash(self) -> str: + ... + + +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> dict[str, str]: + ... + + +class ModelImpl(str, enum.Enum): + AUTO = "auto" + VLLM = "vllm" + TRANSFORMERS = "transformers" + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. + task: The task to use the model for. Each vLLM instance only supports + one task, even if the same model can be used for multiple tasks. + When the model only supports one task, "auto" can be used to select + it; otherwise, you must specify explicitly which task to use. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, "slow" will always use the slow tokenizer, + "mistral" will always use the tokenizer from `mistral_common`, and + "custom" will use --tokenizer to select the preregistered tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images or + videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + spec_target_max_model_len: Specify the the maximum length for spec + decoding draft models. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. 
+ If False, we will use CUDA graph and eager execution in hybrid. + If None, the user did not specify, so default to False. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. + max_logprobs: Maximum number of log probabilities. Defaults to 20. + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. + limit_mm_per_prompt: Maximum number of data items per modality + per prompt. Only applicable for multimodal models. + use_async_output_proc: Whether to use async output processor. + Defaults to True. + config_format: The config format which shall be loaded. + Defaults to 'auto' which defaults to 'hf'. + hf_overrides: If a dictionary, contains arguments to be forwarded to the + HuggingFace config. If a callable, it is called to update the + HuggingFace config. + mm_processor_kwargs: Arguments to be forwarded to the model's processor + for multi-modal data, e.g., image processor. + disable_mm_preprocessor_cache: If true, then disables caching of the + multi-modal preprocessor/mapper. (not recommended) + override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. + override_pooler_config: Initialize non default pooling config or + override default pooling config for the pooling model. + logits_processor_pattern: Optional regex pattern specifying valid + logits processor qualified names that can be passed with the + `logits_processors` extra completion argument. Defaults to None, + which allows no processors. + generation_config: Configuration parameter file for generation. + model_impl: Which implementation of the model to use: + "auto" will try to use the vLLM implementation if it exists and + fall back to the Transformers implementation if no vLLM + implementation is available. + "vllm" will use the vLLM model implementation. + "transformers" will use the Transformers model implementation. + override_generation_config: Override the generation config with the + given config. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + factors.append(self.model) + factors.append(self.dtype) + factors.append(self.quantization) + factors.append(self.revision) + factors.append(self.code_revision) + factors.append(self.trust_remote_code) + factors.append(self.rope_scaling) + factors.append(self.rope_theta) + # rope cos/sin cache depends on the max_position_embeddings + factors.append( + getattr(self.hf_config, "max_position_embeddings", "None")) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __init__( + self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + hf_config_path: Optional[str] = None, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + disable_cascade_attn: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, list[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: str = "auto", + enable_sleep_mode: bool = False, + override_generation_config: Optional[dict[str, Any]] = None, + model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, + ) -> None: + self.model = maybe_model_redirect(model) + self.tokenizer = maybe_model_redirect(tokenizer) + + self.hf_config_path = hf_config_path + if isinstance(hf_config_path, str): + self.hf_config_path = maybe_model_redirect(hf_config_path) + + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.allowed_local_media_path = allowed_local_media_path + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + self.model_impl = model_impl + + if hf_overrides is None: + hf_overrides = {} + + if callable(hf_overrides): + hf_overrides_kw = {} + hf_overrides_fn = hf_overrides + else: + hf_overrides_kw = hf_overrides + hf_overrides_fn = None + + if rope_scaling is not None: + hf_override: dict[str, Any] = {"rope_scaling": rope_scaling} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides) + msg = ( + "`--rope-scaling` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + if rope_theta is not None: + hf_override = {"rope_theta": rope_theta} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides) + msg = ( + "`--rope-theta` will be removed in a future release. 
" + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) + + if (backend := envs.VLLM_ATTENTION_BACKEND + ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: + raise ValueError( + "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " + "module was not found. See " + "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 + "for instructions on how to install it.") + + # The tokenizer version is consistent with the model version by default. + if tokenizer_revision is None: + self.tokenizer_revision = revision + else: + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.enforce_eager = enforce_eager + self.max_seq_len_to_capture = max_seq_len_to_capture + self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window + self.disable_cascade_attn = disable_cascade_attn + self.skip_tokenizer_init = skip_tokenizer_init + self.enable_sleep_mode = enable_sleep_mode + + from vllm.platforms import current_platform + + if self.enable_sleep_mode and not current_platform.is_cuda(): + raise ValueError("Sleep mode is only supported on CUDA devices.") + + hf_config = get_config(self.hf_config_path or self.model, + trust_remote_code, revision, code_revision, + config_format) + + if hf_overrides_kw: + logger.info("Overriding HF config with %s", hf_overrides_kw) + hf_config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.info("Overriding HF config with %s", hf_overrides_fn) + hf_config = hf_overrides_fn(hf_config) + + self.hf_config = hf_config + + self.hf_text_config = get_hf_text_config(self.hf_config) + self.attention_chunk_size = getattr(self.hf_text_config, + "attention_chunk_size", None) + self.encoder_config = self._get_encoder_config() + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, revision) + self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self.use_async_output_proc = use_async_output_proc + self.mm_processor_kwargs = mm_processor_kwargs + self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache + + # Set enforce_eager to False if the value is unset. + if self.enforce_eager is None: + self.enforce_eager = False + + interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] + sliding_window = getattr(self.hf_text_config, "sliding_window", None) + has_interleaved_attention = (sliding_window is not None) and ( + isinstance(sliding_window, list) or + (self.hf_text_config.model_type in interleaved_attn_models)) + + if (not self.disable_sliding_window and has_interleaved_attention): + if (backend := + envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + + logger.warning_once( + f"{self.hf_text_config.model_type} has interleaved " + "attention, which is currently not supported by the " + f"{backend} backend. Disabling sliding window and capping " + "the max length to the sliding window size " + f"({sliding_window_len_min}).") + self.disable_sliding_window = True + else: + # for a model with interleaved attention, + # the scheduler and the model treat it as full attention + # (i.e., not dropping any tokens outside the window). + # only the attention layer itself is aware of the sliding + # window, and use the window size to compute the attention. 
+ self.hf_text_config.interleaved_sliding_window = sliding_window + delattr(self.hf_text_config, "sliding_window") + sliding_window = None + + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + spec_target_max_model_len=spec_target_max_model_len, + encoder_config=self.encoder_config) + self.served_model_name = get_served_model_name(model, + served_model_name) + self.multimodal_config = self._init_multimodal_config( + limit_mm_per_prompt) + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + + self.is_attention_free = self._init_attention_free() + self.is_hybrid = self._init_is_hybrid() + self.has_noops = self._init_has_noops() + self.has_inner_state = self._init_has_inner_state() + + if current_platform.is_neuron(): + self.override_neuron_config = override_neuron_config + else: + self.override_neuron_config = None + + supported_tasks, task = self._resolve_task(task) + self.supported_tasks = supported_tasks + self.task: Final = task + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" + + self.pooler_config = self._init_pooler_config(override_pooler_config) + self.logits_processor_pattern = logits_processor_pattern + + self.generation_config = generation_config + self.override_generation_config = override_generation_config or {} + + self._verify_quantization() + self._verify_cuda_graph() + import os + enforce_cuda_graph = os.environ.get("VLLM_ENFORCE_CUDA_GRAPH",None) + if enforce_cuda_graph is not None and enforce_cuda_graph in ["1", "y", "Y"]: + self.enforce_eager = False + else: + self.enforce_eager = True + logger.warning_once( + "Please export VLLM_ENFORCE_CUDA_GRAPH=1 to enable cuda graph. " + "For now, cuda graph is not used and --enforce-eager is disabled ," + "we are trying to use cuda graph as the default mode") + self._verify_bnb_config() + + @property + def registry(self): + return ModelRegistry + + @property + def architectures(self) -> list[str]: + return getattr(self.hf_config, "architectures", []) + + def maybe_pull_model_tokenizer_for_s3(self, model: str, + tokenizer: str) -> None: + """ + Pull the model config or tokenizer to a temporary + directory in case of S3. + + Args: + model: The model name or path. + tokenizer: The tokenizer name or path. 
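+
+        Example (hypothetical URI): for ``model="s3://my-bucket/my-model"``,
+        the config files are pulled into a local temporary directory and
+        ``self.model`` is re-pointed to that directory.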
+ + """ + if is_s3(model) or is_s3(tokenizer): + if is_s3(model): + s3_model = S3Model() + s3_model.pull_files( + model, allow_pattern=["*.model", "*.py", "*.json"]) + self.model_weights = self.model + self.model = s3_model.dir + + if is_s3(tokenizer): + s3_tokenizer = S3Model() + s3_tokenizer.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = s3_tokenizer.dir + + def _init_multimodal_config( + self, limit_mm_per_prompt: Optional[Mapping[str, int]] + ) -> Optional["MultiModalConfig"]: + if self.registry.is_multimodal_model(self.architectures): + return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) + + if limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + + return None + + def _get_encoder_config(self): + return get_sentence_transformer_tokenizer_config( + self.model, self.revision) + + def _init_pooler_config( + self, + override_pooler_config: Optional["PoolerConfig"], + ) -> Optional["PoolerConfig"]: + + if self.runner_type == "pooling": + user_config = override_pooler_config or PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + # Only set values that are not overridden by the user + for k, v in base_config.items(): + if getattr(user_config, k) is None: + setattr(user_config, k, v) + + return user_config + + return None + + def _init_attention_free(self) -> bool: + return self.registry.is_attention_free_model(self.architectures) + + def _init_is_hybrid(self) -> bool: + return self.registry.is_hybrid_model(self.architectures) + + def _init_has_noops(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return self.registry.is_noops_model(architectures) + + def _init_has_inner_state(self) -> bool: + return self.registry.model_has_inner_state(self.architectures) + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. 
Must be " + "either 'auto', 'slow', 'mistral' or 'custom'.") + self.tokenizer_mode = tokenizer_mode + + def _get_preferred_task( + self, + architectures: list[str], + supported_tasks: set[_ResolvedTask], + ) -> Optional[_ResolvedTask]: + model_id = self.model + if get_pooling_config(model_id, self.revision): + return "embed" + if self.registry.is_cross_encoder_model(architectures): + return "score" + if self.registry.is_transcription_model(architectures): + return "transcription" + + suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ + # Other models follow this pattern + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("ForSequenceClassification", "classify"), + ("ChatModel", "generate"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "embed"), + ("RewardModel", "reward"), + ] + _, arch = self.registry.inspect_model_cls(architectures) + + for suffix, pref_task in suffix_to_preferred_task: + if arch.endswith(suffix) and pref_task in supported_tasks: + return pref_task + + return None + + def _resolve_task( + self, + task_option: Union[TaskOption, Literal["draft"]], + ) -> tuple[set[_ResolvedTask], _ResolvedTask]: + if task_option == "draft": + return {"draft"}, "draft" + + registry = self.registry + architectures = self.architectures + + runner_support: dict[RunnerType, bool] = { + # NOTE: Listed from highest to lowest priority, + # in case the model supports multiple of them + "transcription": registry.is_transcription_model(architectures), + "generate": registry.is_text_generation_model(architectures), + "pooling": registry.is_pooling_model(architectures), + } + supported_runner_types_lst: list[RunnerType] = [ + runner_type + for runner_type, is_supported in runner_support.items() + if is_supported + ] + + supported_tasks_lst: list[_ResolvedTask] = [ + task for runner_type in supported_runner_types_lst + for task in _RUNNER_TASKS[runner_type] + ] + supported_tasks = set(supported_tasks_lst) + + if task_option == "auto": + selected_task = next(iter(supported_tasks_lst)) + + if len(supported_tasks_lst) > 1: + preferred_task = self._get_preferred_task( + architectures, supported_tasks) + if preferred_task is not None: + selected_task = preferred_task + + logger.info( + "This model supports multiple tasks: %s. " + "Defaulting to '%s'.", supported_tasks, selected_task) + else: + # Aliases + if task_option == "embedding": + preferred_task = self._get_preferred_task( + architectures, supported_tasks) + if preferred_task != "embed": + msg = ("The 'embedding' task will be restricted to " + "embedding models in a future release. Please " + "pass `--task classify`, `--task score`, or " + "`--task reward` explicitly for other pooling " + "models.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + task_option = preferred_task or "embed" + + if task_option not in supported_tasks: + msg = ( + f"This model does not support the '{task_option}' task. 
" + f"Supported tasks: {supported_tasks}") + raise ValueError(msg) + + selected_task = task_option + + return supported_tasks, selected_task + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = QUANTIZATION_METHODS + optimized_quantization_methods = [ + "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", + "awq_marlin", "fbgemm_fp8", "compressed_tensors", + "compressed-tensors", "experts_int8", "quark", "nvfp4" + ] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. + quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + + # Detect which checkpoint is it + for name in QUANTIZATION_METHODS: + method = get_quantization_config(name) + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + from vllm.platforms import current_platform + current_platform.verify_quantization(self.quantization) + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.", self.quantization) + + def _verify_cuda_graph(self) -> None: + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) + ROCM_UNSUPPORTED_MODELS = ['mllama'] + if (self.hf_config.model_type in ROCM_UNSUPPORTED_MODELS + and not self.enforce_eager and current_platform.is_rocm()): + logger.warning( + "CUDA graph is not supported for %s on ROCm yet, fallback " + "to the eager mode.", self.hf_config.model_type) + self.enforce_eager = True + + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.45.3) with 8-bit models does not + yet support CUDA graph. + # TODO Remove this when bitsandbytes supports. 
+ """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = (getattr(self.hf_config, + "quantization_config", None) + is not None) + is_8bit = (self.hf_config.quantization_config.get( + "load_in_8bit", False) if has_quantization_config else False) + if all([ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ]): + logger.warning( + "CUDA graph is not supported on BitsAndBytes 8bit yet, " + "fallback to the eager mode.") + + self.enforce_eager = True + + def _verify_with_expert_parallelism(self) -> None: + num_expert_names = [ + "moe_num_experts", # Dbrx + "num_experts", # Jamba + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = 0 + for name in num_expert_names: + num_experts = getattr(self.hf_text_config, name, 0) + if num_experts > 0: + break + if num_experts < 1: + raise ValueError( + "Number of experts in the model must be greater than 0 " + "when expert parallelism is enabled.") + + def verify_async_output_proc(self, parallel_config, speculative_config, + device_config) -> None: + if not self.use_async_output_proc: + # Nothing to check + return + + if parallel_config.pipeline_parallel_size > 1: + self.use_async_output_proc = False + return + + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + from vllm.platforms import current_platform + if not current_platform.is_async_output_supported(self.enforce_eager): + self.use_async_output_proc = False + return + + if envs.VLLM_USE_RAY_SPMD_WORKER: + self.use_async_output_proc = False + return + + # Async postprocessor is not necessary for pooling models + # since there is no token generation + if self.runner_type == "pooling": + self.use_async_output_proc = False + + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + if speculative_config: + self.use_async_output_proc = False + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + + if parallel_config.distributed_executor_backend == "external_launcher": + assert self.seed is not None, ( + "Seed must be set when using external launcher backend to " + "make sure sampling results are the same across workers.") + + total_num_attention_heads = getattr(self.hf_text_config, + "num_attention_heads", 0) + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size}).") + + if parallel_config.enable_expert_parallel: + self._verify_with_expert_parallelism() + + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if pipeline_parallel_size > 1: + if not self.registry.is_pp_supported_model(self.architectures): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") + + if self.use_async_output_proc: + self.use_async_output_proc = False + + def get_hf_config_sliding_window( + self) -> Union[Optional[int], list[Optional[int]]]: + """Get the sliding window size, or None if disabled.""" + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. 
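+        # (e.g. a Qwen2-style config may define "sliding_window": 32768 while
+        # setting "use_sliding_window": false; in that case this returns None.)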
+ if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): + return None + return getattr(self.hf_text_config, "sliding_window", None) + + def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + + def get_vocab_size(self) -> int: + return self.hf_text_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_text_config.hidden_size + + @property + def is_deepseek_mla(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'): + return self.hf_text_config.kv_lora_rank is not None + elif self.hf_text_config.model_type == 'eagle': + # if the model is an EAGLE module, check for the + # underlying architecture + return self.hf_text_config.model.model_type in \ + ('deepseek_v2', 'deepseek_v3') \ + and self.hf_text_config.kv_lora_rank is not None + return False + + def get_head_size(self) -> int: + # TODO remove hard code + if self.is_deepseek_mla: + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", + 0) + if self.use_mla: + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim + else: + qk_nope_head_dim = getattr(self.hf_text_config, + "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim + + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + return self.hf_text_config.attention_head_dim + + if self.is_attention_free: + return 0 + + # if hasattr(self.hf_text_config, "head_dim"): + # return self.hf_text_config.head_dim + if getattr(self.hf_text_config, "head_dim", None) is not None: + return self.hf_text_config.head_dim + # FIXME(woosuk): This may not be true for all models. + return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_text_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. 
+ return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + + if self.hf_config.model_type == "nemotron-nas": + for block in self.hf_config.block_configs: + if not block.attention.no_op: + return self.hf_config.num_attention_heads \ + // block.attention.n_heads_in_group + + raise RuntimeError("Couldn't determine number of kv heads") + + if self.is_attention_free: + return 0 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + if self.use_mla: + # When using MLA during decode it becomes MQA + return 1 + + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, + total_num_kv_heads // parallel_config.tensor_parallel_size) + + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + return num_heads // parallel_config.tensor_parallel_size + + def get_layers_start_end_indices( + self, parallel_config: "ParallelConfig") -> tuple[int, int]: + from vllm.distributed.utils import get_pp_indices + if self.hf_text_config.model_type == "deepseek_mtp": + total_num_hidden_layers = getattr(self.hf_text_config, + "num_nextn_predict_layers", 0) + else: + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) + # the layout order is: DP x PP x TP + pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size + ) % parallel_config.pipeline_parallel_size + pp_size = parallel_config.pipeline_parallel_size + + if self.hf_text_config.model_type == "deepseek_mtp": + layers_per_partition = total_num_hidden_layers // pp_size + partitions = [layers_per_partition for _ in range(pp_size)] + + if remaining_layers := total_num_hidden_layers % pp_size: + for i in range(2, remaining_layers + 2): + partitions[-i] += 1 + logger.info("Hidden layers were unevenly partitioned: %s", + ",".join(str(p) for p in partitions)) + logger.info("This can be manually overridden using the " + "VLLM_PP_LAYER_PARTITION environment variable") + start = sum(partitions[:pp_rank]) + end = start + partitions[pp_rank] + return start, end + start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) + return start, end + + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + start, end = self.get_layers_start_end_indices(parallel_config) + return end - start + + def get_num_layers_by_block_type( + self, + parallel_config: "ParallelConfig", + block_type: LayerBlockType = LayerBlockType.attention, + ) -> int: + # 
This function relies on 'layers_block_type' in hf_config, + # for w/o this attribute, we will need to have workarounds like so + attn_block_type = block_type == LayerBlockType.attention + is_transformer = not self.is_hybrid and \ + not self.has_noops and \ + not self.is_attention_free + start, end = self.get_layers_start_end_indices(parallel_config) + + if is_transformer: + # Handle the basic case first + return end - start if attn_block_type else 0 + elif self.is_attention_free: + # Attention free + # Note that this code assumes there + # is only one type of attention-free block type. + return 0 if attn_block_type else end - start + elif self.has_noops: + block_configs = self.hf_config.block_configs + return sum(not bc.attention.no_op + for bc in block_configs[start:end]) + else: + # Hybrid model Jamba + layers_block_type_value = getattr(self.hf_config, + "layers_block_type", None) + if layers_block_type_value is not None: + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + if attn_block_type: + return sum(t == "hybrid" + for t in layers_block_type_value[start:end]) + else: + return self.get_num_layers(parallel_config) + return sum(t == block_type.value + for t in layers_block_type_value[start:end]) + + # Hybrid model Minimax + attn_type_list = getattr(self.hf_config, "attn_type_list", None) + if attn_type_list: + return sum(t == 1 for t in attn_type_list[start:end]) + + if layers_block_type_value is None and attn_type_list is None: + raise ValueError( + "The model is an hybrid without a" + "layers_block_type or an attn_type_list in the hf_config," + "cannot determine the num of " + f"{block_type.value} layers") + + return sum(t == 1 for t in attn_type_list[start:end]) + + def get_multimodal_config(self) -> "MultiModalConfig": + """ + Get the multimodal configuration of the model. + + Raises: + ValueError: If the model is not multimodal. + """ + if self.multimodal_config is None: + raise ValueError("The model is not multimodal.") + + return self.multimodal_config + + def try_get_generation_config(self) -> dict[str, Any]: + if self.generation_config in ("auto", "vllm"): + config = try_get_generation_config( + self.hf_config_path or self.model, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + ) + else: + config = try_get_generation_config( + self.generation_config, + trust_remote_code=self.trust_remote_code, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + def get_diff_sampling_param(self) -> dict[str, Any]: + """ + This method returns a dictionary containing the parameters + that differ from the default sampling parameters. If + `generation_config` is `"vllm"`, an empty dictionary is returned. + + Returns: + dict[str, Any]: A dictionary with the differing sampling + parameters, if `generation_config` is `"vllm"` an empty dictionary. 
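+
+        For example (hypothetical values), a generation_config.json that sets
+        ``temperature: 0.6`` and ``top_p: 0.95`` yields
+        ``{"temperature": 0.6, "top_p": 0.95}``.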
+ """ + if self.generation_config == "vllm": + config = {} + else: + config = self.try_get_generation_config() + + # Overriding with given generation config + config.update(self.override_generation_config) + + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + "max_new_tokens", + ] + if any(p in config for p in available_params): + diff_sampling_param = { + p: config.get(p) + for p in available_params if config.get(p) is not None + } + # Huggingface definition of max_new_tokens is equivalent + # to vLLM's max_tokens + if "max_new_tokens" in diff_sampling_param: + diff_sampling_param["max_tokens"] = diff_sampling_param.pop( + "max_new_tokens") + else: + diff_sampling_param = {} + + if diff_sampling_param: + logger.warning_once( + "Default sampling parameters have been overridden by the " + "model's Hugging Face generation config recommended from the " + "model creator. If this is not intended, please relaunch " + "vLLM instance with `--generation-config vllm`.") + return diff_sampling_param + + @property + def is_encoder_decoder(self) -> bool: + """Extract the HF encoder/decoder model flag.""" + return is_encoder_decoder(self.hf_config) + + @property + def uses_mrope(self) -> bool: + return uses_mrope(self.hf_config) + + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + + @property + def is_cross_encoder(self) -> bool: + return self.registry.is_cross_encoder_model(self.architectures) + + @property + def use_mla(self) -> bool: + return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE + + @property + def supported_runner_types(self) -> set[RunnerType]: + return {_TASK_RUNNER[task] for task in self.supported_tasks} + + @property + def runner_type(self) -> RunnerType: + return _TASK_RUNNER[self.task] + + @property + def is_v1_compatible(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_v1_compatible(architectures) + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. + is_attention_free: Whether the model is attention-free. + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the + profiled num_gpu_blocks if specified. Does nothing if None. + sliding_window: Sliding window size for the KV cache. + enable_prefix_caching: Whether to enable prefix caching. + cpu_offload_gb: Size of the CPU offload buffer in GiB. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. 
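+        # Only cache_dtype is collected above: block size, GPU memory
+        # utilization and swap space affect runtime behaviour but not the
+        # structure of the compiled graph.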
+ hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: float, + cache_dtype: str, + is_attention_free: bool = False, + num_gpu_blocks_override: Optional[int] = None, + sliding_window: Optional[int] = None, + enable_prefix_caching: bool = False, + prefix_caching_hash_algo: str = "builtin", + cpu_offload_gb: float = 0, + calculate_kv_scales: Optional[bool] = None, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * GiB_bytes + self.num_gpu_blocks_override = num_gpu_blocks_override + self.cache_dtype = cache_dtype + self.is_attention_free = is_attention_free + self.sliding_window = sliding_window + self.enable_prefix_caching = enable_prefix_caching + self.prefix_caching_hash_algo = prefix_caching_hash_algo + self.cpu_offload_gb = cpu_offload_gb + self.calculate_kv_scales = calculate_kv_scales + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + # Will be set after profiling. + self.num_gpu_blocks: Optional[int] = None + self.num_cpu_blocks: Optional[int] = None + + # Set calculate_kv_scales to False if the value is unset. + if self.calculate_kv_scales is None: + self.calculate_kv_scales = False + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.cpu_offload_gb < 0: + raise ValueError("CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + + if self.enable_prefix_caching and self.prefix_caching_hash_algo not in ( + "builtin", "sha256"): + raise ValueError( + "Unknown prefix caching hash algorithm: " + f"{self.prefix_caching_hash_algo}. Must be either " + "'builtin' or 'sha256'.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. 
" + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. %s", msg) + + +@dataclass +class TokenizerPoolConfig: + """Configuration for the tokenizer pool. + + Args: + pool_size: Number of tokenizer workers in the pool. + pool_type: Type of the pool. + extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. + """ + pool_size: int + pool_type: Union[str, type["BaseTokenizerGroup"]] + extra_config: dict + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if self.pool_type not in ("ray", ) and not isinstance( + self.pool_type, type): + raise ValueError(f"Unknown pool type: {self.pool_type}") + if not isinstance(self.extra_config, dict): + raise ValueError("extra_config must be a dictionary.") + + @classmethod + def create_config( + cls, tokenizer_pool_size: int, + tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]], + tokenizer_pool_extra_config: Optional[Union[str, dict]] + ) -> Optional["TokenizerPoolConfig"]: + """Create a TokenizerPoolConfig from the given parameters. + + If tokenizer_pool_size is 0, return None. + + Args: + tokenizer_pool_size: Number of tokenizer workers in the pool. + tokenizer_pool_type: Type of the pool. + tokenizer_pool_extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. This can be a JSON string (will be parsed). + """ + if tokenizer_pool_size: + if isinstance(tokenizer_pool_extra_config, str): + tokenizer_pool_extra_config_parsed = json.loads( + tokenizer_pool_extra_config) + else: + tokenizer_pool_extra_config_parsed = ( + tokenizer_pool_extra_config or {}) + tokenizer_pool_config = cls(tokenizer_pool_size, + tokenizer_pool_type, + tokenizer_pool_extra_config_parsed) + else: + tokenizer_pool_config = None + return tokenizer_pool_config + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + GGUF = "gguf" + BITSANDBYTES = "bitsandbytes" + MISTRAL = "mistral" + RUNAI_STREAMER = "runai_streamer" + FASTSAFETENSORS = "fastsafetensors" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. 
+ "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + "sharded_state" will load weights from pre-sharded checkpoint files, + supporting efficient loading of tensor-parallel models. + "gguf" will load weights from GGUF format files. + "mistral" will load weights from consolidated safetensors files used + by Mistral models. + "runai_streamer" will load weights from RunAI streamer format files. + model_loader_extra_config: The extra config for the model loader. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + use_tqdm_on_load: Whether to enable tqdm for showing progress bar during + loading. Default to True + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + ignore_patterns: Optional[Union[list[str], str]] = None + use_tqdm_on_load: bool = True + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + if isinstance(self.load_format, str): + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + +@dataclass +class ParallelConfig: + """Configuration for the distributed execution.""" + + pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. + tensor_parallel_size: int = 1 # Number of tensor parallel groups. + data_parallel_size: int = 1 # Number of data parallel groups. + data_parallel_rank: int = 0 # Rank of the data parallel group. + # Local rank of the data parallel group, defaults to global rank. + data_parallel_rank_local: Optional[int] = None + # IP of the data parallel master. + data_parallel_master_ip: str = "127.0.0.1" + data_parallel_master_port: int = 29500 # Port of the data parallel master. + enable_expert_parallel: bool = False # Use EP instead of TP for MoE layers. + num_virtual_engine: int = 1 # Number of virtual engines. + + # Maximum number of multiple batches + # when load model sequentially. To avoid RAM OOM when using tensor + # parallel and large models. + max_parallel_loading_workers: Optional[int] = None + + # Disable the custom all-reduce kernel and fall back to NCCL. + disable_custom_all_reduce: bool = False + + # Config for the tokenizer pool. If None, will use synchronous tokenization. 
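+    # (TokenizerPoolConfig.create_config above likewise returns None when
+    # tokenizer_pool_size is 0, which selects synchronous tokenization.)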
+ tokenizer_pool_config: Optional[TokenizerPoolConfig] = None + + # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + ray_workers_use_nsight: bool = False + + # ray distributed model workers placement group. + placement_group: Optional["PlacementGroup"] = None + + # Backend to use for distributed model + # workers, either "ray" or "mp" (multiprocessing). If the product + # of pipeline_parallel_size and tensor_parallel_size is less than + # or equal to the number of GPUs available, "mp" will be used to + # keep processing on a single host. Otherwise, this will default + # to "ray" if Ray is installed and fail otherwise. Note that tpu + # and hpu only support Ray for distributed inference. + distributed_executor_backend: Optional[Union[str, + type["ExecutorBase"]]] = None + + # the full name of the worker class to use. If "auto", the worker class + # will be determined based on the platform. + worker_cls: str = "auto" + sd_worker_cls: str = "auto" + worker_extension_cls: str = "" + + # world_size is TPxPP, it affects the number of workers we create. + world_size: int = field(init=False) + # world_size_across_dp is TPxPPxDP, it is the size of the world + # including data parallelism. + world_size_across_dp: int = field(init=False) + + rank: int = 0 + + def get_next_dp_init_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. To avoid port conflicts, we + increment the port number each time we need to initialize a + new process group related to data parallelism. + """ + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + return answer + + def stateless_init_dp_group(self) -> "ProcessGroup": + from vllm.distributed.utils import ( + stateless_init_torch_distributed_process_group) + + # use gloo since the engine process might not have cuda device + dp_group = stateless_init_torch_distributed_process_group( + self.data_parallel_master_ip, + self.get_next_dp_init_port(), + self.data_parallel_rank, + self.data_parallel_size, + backend="gloo") + + return dp_group + + @staticmethod + def has_unfinished_dp(dp_group: "ProcessGroup", + has_unfinished: bool) -> bool: + tensor = torch.tensor([has_unfinished], + dtype=torch.int32, + device="cpu") + # dp rank 0: has_unfinished_seqs=True + # dp rank 1: has_unfinished_seqs=False + # aggregated: has_unfinished_seqs=True + # so this is an OR operation, i.e. MAX in integers + torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group) + aggregated_has_unfinished = bool(tensor.item()) + return aggregated_has_unfinished + + def compute_hash(self): + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.pipeline_parallel_size) + factors.append(self.tensor_parallel_size) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.data_parallel_size > 1: + # Data parallel was specified in the engine args. 
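+ # get_open_port() grabs a free port once; get_next_dp_init_port() then
+ # hands out master_port, master_port + 1, ... so that each data-parallel
+ # process group initialized later gets its own port and avoids conflicts.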
+ self.data_parallel_master_port = get_open_port() + # TODO multi-node + else: + # Otherwise fall back to env vars (e.g. for offline SPMD case). + self.data_parallel_size = envs.VLLM_DP_SIZE + self.data_parallel_rank = envs.VLLM_DP_RANK + self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL + self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP + self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + + self.world_size_across_dp = self.world_size * self.data_parallel_size + + if self.distributed_executor_backend == "external_launcher": + import os + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + logger.info("Disabling V1 multiprocessing for external launcher.") + + ray_only_devices: list[str] = [] + from vllm.platforms import current_platform + if (current_platform.device_type in ray_only_devices + and self.world_size > 1): + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + if self.distributed_executor_backend != "ray": + raise ValueError( + f"{current_platform.device_type.upper()} backend only " + "supports Ray for distributed inference.") + + if self.distributed_executor_backend is None and self.world_size > 1: + # We use multiprocessing by default if world_size fits on the + # current node and we aren't in a ray placement group. + + from vllm.executor import ray_utils + backend = "mp" + ray_found = ray_utils.ray_is_available() + if current_platform.is_neuron(): + # neuron uses single process to control multiple devices + backend = "uni" + elif (current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size): + if not ray_found: + raise ValueError("Unable to load Ray which is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`.") from ray_utils.ray_import_err + backend = "ray" + elif ray_found: + if self.placement_group: + backend = "ray" + else: + from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): + from ray.util import get_current_placement_group + if get_current_placement_group(): + backend = "ray" + self.distributed_executor_backend = backend + logger.info("Defaulting to use %s for distributed inference", + backend) + + if self.distributed_executor_backend is None and self.world_size == 1: + self.distributed_executor_backend = "uni" + + self._verify_args() + + @property + def use_ray(self) -> bool: + return self.distributed_executor_backend == "ray" or ( + isinstance(self.distributed_executor_backend, type) + and self.distributed_executor_backend.uses_ray) + + def _verify_args(self) -> None: + # Lazy import to avoid circular import + from vllm.executor.executor_base import ExecutorBase + from vllm.platforms import current_platform + if self.distributed_executor_backend not in ( + "ray", "mp", "uni", + "external_launcher", None) and not (isinstance( + self.distributed_executor_backend, type) and issubclass( + self.distributed_executor_backend, ExecutorBase)): + raise ValueError( + "Unrecognized distributed executor backend " + f"{self.distributed_executor_backend}. 
Supported " + "values are 'ray', 'mp' 'uni', 'external_launcher' or" + " custom ExecutorBase subclass.") + if self.use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + if not current_platform.use_custom_allreduce(): + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported on current platform.") + if self.ray_workers_use_nsight and not self.use_ray: + raise ValueError("Unable to use nsight profiling unless workers " + "run with Ray.") + + assert isinstance(self.worker_extension_cls, str), ( + "worker_extension_cls must be a string (qualified class name).") + + +@dataclass +class SchedulerConfig: + """Scheduler configuration.""" + + runner_type: str = "generate" # The runner type to launch for the model. + + # Maximum number of tokens to be processed in a single iteration. + max_num_batched_tokens: int = field(default=None) # type: ignore + + # Maximum number of sequences to be processed in a single iteration. + max_num_seqs: int = 128 + + # Maximum length of a sequence (including prompt and generated text). + max_model_len: int = 8192 + + # Maximum number of sequences that can be partially prefilled concurrently + max_num_partial_prefills: int = 1 + + # Maximum number of "very long prompt" sequences that can be prefilled + # concurrently (long is defined by long_prefill_threshold) + max_long_partial_prefills: int = 1 + + # calculate context length that determines which sequences are + # considered "long" + long_prefill_token_threshold: int = 0 + + # The number of slots to allocate per sequence per + # step, beyond the known token ids. This is used in speculative + # decoding to store KV activations of tokens which may or may not be + # accepted. + num_lookahead_slots: int = 0 + + # Apply a delay (of delay factor multiplied by previous + # prompt latency) before scheduling next prompt. + delay_factor: float = 0.0 + + # If True, prefill requests can be chunked based + # on the remaining max_num_batched_tokens. + enable_chunked_prefill: bool = False + + is_multimodal_model: bool = False + + # NOTE: The following multimodal encoder budget will be initialized to + # max_num_batched_tokens and overridden in case max multimodal embedding + # size is larger. + # TODO (ywang96): Make these configurable. + # Multimodal encoder compute budget, only used in V1 + max_num_encoder_input_tokens: int = field(default=None) # type: ignore + + # Multimodal encoder cache size, only used in V1 + encoder_cache_size: int = field(default=None) # type: ignore + + # Whether to perform preemption by swapping or + # recomputation. If not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + preemption_mode: Optional[str] = None + + num_scheduler_steps: int = 1 + + multi_step_stream_outputs: bool = False + + # Private API. If used, scheduler sends delta data to + # workers instead of an entire data. It should be enabled only + # when SPMD worker architecture is enabled. I.e., + # VLLM_USE_RAY_SPMD_WORKER=1 + send_delta_data: bool = False + + # The scheduling policy to use. "fcfs" (default) or "priority". + policy: str = "fcfs" + + chunked_prefill_enabled: bool = field(init=False) + + # scheduler class or path. 
"vllm.core.scheduler.Scheduler" (default) + # or "mod.custom_class". + scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + if self.num_scheduler_steps > 1: + # Multi-step Chunked-Prefill doesn't allow prompt-chunking + # for now. Have max_num_batched_tokens set to max_model_len + # so we don't reject sequences on account of a short + # max_num_batched_tokens. + self.max_num_batched_tokens = max( + self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) + else: + self.max_num_batched_tokens = ( + _DEFAULT_MAX_NUM_BATCHED_TOKENS) + else: + # If max_model_len is too short, use + # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value + # for higher throughput. + self.max_num_batched_tokens = max( + self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.runner_type == "pooling": + # Choose specific value for higher throughput + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if self.is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + self.max_num_encoder_input_tokens = self.max_num_batched_tokens + self.encoder_cache_size = self.max_num_batched_tokens + + if self.enable_chunked_prefill: + logger.info( + "Chunked prefill is enabled with max_num_batched_tokens=%d.", + self.max_num_batched_tokens) + + self.chunked_prefill_enabled = self.enable_chunked_prefill + if self.max_num_partial_prefills > 1: + if self.long_prefill_token_threshold == 0: + self.long_prefill_token_threshold = int(self.max_model_len * + 0.04) + + logger.info( + "Concurrent partial prefills enabled with " + "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " + "long_prefill_token_threshold=%d", + self.max_num_partial_prefills, self.max_long_partial_prefills, + self.long_prefill_token_threshold) + + self._verify_args() + + def _verify_args(self) -> None: + if (self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled): + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. 
Please increase max_num_batched_tokens or " + "decrease max_model_len.") + + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + + if self.num_scheduler_steps < 1: + raise ValueError( + "num_scheduler_steps " + f"({self.num_scheduler_steps}) must be greater than or " + "equal to 1.") + + if self.max_num_partial_prefills < 1: + raise ValueError( + f"max_num_partial_prefills ({self.max_num_partial_prefills}) " + "must be greater than or equal to 1.") + elif self.max_num_partial_prefills > 1: + if not self.chunked_prefill_enabled: + raise ValueError("Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1.") + + if self.long_prefill_token_threshold > self.max_model_len: + raise ValueError( + "long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than the max_model_len ({self.max_model_len}).") + + if (self.max_long_partial_prefills + < 1) or (self.max_long_partial_prefills + > self.max_num_partial_prefills): + raise ValueError( + f"max_long_partial_prefills ({self.max_long_partial_prefills}) " + "must be greater than or equal to 1 and less than or equal to " + f"max_num_partial_prefills ({self.max_num_partial_prefills}).") + + @property + def is_multi_step(self) -> bool: + return self.num_scheduler_steps > 1 + + +class DeviceConfig: + device: Optional[torch.device] + device_type: str + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # the device/platform information will be summarized + # by torch/vllm automatically. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __init__(self, device: str = "auto") -> None: + if device == "auto": + # Automated device type detection + from vllm.platforms import current_platform + self.device_type = current_platform.device_type + if not self.device_type: + raise RuntimeError( + "Failed to infer device type, please set " + "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " + "to turn on verbose logging to help debug the issue.") + else: + # Device type is assigned explicitly + self.device_type = device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron"]: + self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) + + +@dataclass +class SpeculativeConfig: + """ + Configuration for speculative decoding. + Configurable parameters include: + - General Speculative Decoding Control: + - num_speculative_tokens (int): The number of speculative + tokens, if provided. It will default to the number in the draft + model config if present, otherwise, it is required. 
+ - model (Optional[str]): The name of the draft model, eagle head, + or additional weights, if provided. + - method (Optional[str]): The name of the speculative method to use. + If users provide and set the `model` param, the speculative method + type will be detected automatically if possible, if `model` param + is not provided, the method name must be provided. + - Possible values: + - ngram + Related additional configuration: + - prompt_lookup_max (Optional[int]): + Maximum size of ngram token window when using Ngram + proposer, required when method is set to ngram. + - prompt_lookup_min (Optional[int]): + Minimum size of ngram token window when using Ngram + proposer, if provided. Defaults to 1. + - eagle + - medusa + - mlp_speculator + - draft_model + - acceptance_method (str): The method to use for accepting draft + tokens. This can take two possible values: 'rejection_sampler' and + 'typical_acceptance_sampler' for RejectionSampler and + TypicalAcceptanceSampler respectively. If not specified, it + defaults to 'rejection_sampler'. + - Possible values: + - rejection_sampler + - typical_acceptance_sampler + Related additional configuration: + - posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the + posterior probability of a token in the target model + for it to be accepted. This threshold is used only + when we use the TypicalAcceptanceSampler for token + acceptance. + - posterior_alpha (Optional[float]): + Scaling factor for entropy-based threshold, applied + when using TypicalAcceptanceSampler. + - draft_tensor_parallel_size (Optional[int]): The degree of the tensor + parallelism for the draft model. Can only be 1 or the same as the + target model's tensor parallel size. + - draft_pipeline_parallel_size (Optional[int]): The degree of the + pipeline parallelism for the draft model. Can only be 1 or the + same as the target model's pipeline parallel size. + - disable_logprobs (bool): If set to True, token log probabilities are + not returned during speculative decoding. If set to False, token + log probabilities are returned according to the log probability + settings in SamplingParams. If not specified, it defaults to True. + + - Draft Model Configuration: + - quantization (Optional[str]): Quantization method that was used to + quantize the draft model weights. If None, we assume the + model weights are not quantized. Note that it only takes effect + when using the draft model-based speculative method. + - max_model_len (Optional[int]): The maximum model length of the + draft model. Used when testing the ability to skip + speculation for some sequences. + - revision: The specific model version to use for the draft model. It + can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version. + - code_revision: The specific revision to use for the draft model code + on Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + + - Advanced Control: + - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to + batch expansion for scoring proposals. If not specified, it + defaults to False. + - disable_by_batch_size (Optional[int]): Disable speculative decoding + for new incoming requests when the number of enqueued requests is + larger than this value, if provided. + + Although the parameters above are structured hierarchically, there is no + need to nest them during configuration. 
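+
+    Illustrative flat configuration (values chosen here only as an example;
+    the engine supplies the non-configurable fields listed below), as
+    accepted by ``from_dict``:
+    ``{"method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 4}``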
+ + Non-configurable internal parameters include: + - Model Configuration: + - target_model_config (ModelConfig): The configuration of the target + model. + - draft_model_config (ModelConfig): The configuration of the draft + model initialized internal. + - Parallelism Configuration: + - target_parallel_config (ParallelConfig): The parallel configuration + for the target model. + - draft_parallel_config (ParallelConfig): The parallel configuration + for the draft model initialized internal. + - Execution Control: + - enable_chunked_prefill (bool): Whether vLLM is configured to use + chunked prefill or not. Used for raising an error since it's not + yet compatible with speculative decode. + - disable_log_stats (bool): Whether to disable the periodic printing of + stage times in speculative decoding. + """ + # speculative configs from cli args + num_speculative_tokens: int = field(default=None, + init=True) # type: ignore + method: Optional[str] = None + acceptance_method: str = "rejection_sampler" + draft_tensor_parallel_size: Optional[int] = None + draft_pipeline_parallel_size: Optional[int] = None + disable_logprobs: bool = True + + model: Optional[str] = None + quantization: Optional[str] = None + max_model_len: Optional[int] = None + revision: Optional[str] = None + code_revision: Optional[str] = None + + disable_mqa_scorer: bool = False + disable_by_batch_size: Optional[int] = None + prompt_lookup_max: Optional[int] = None + prompt_lookup_min: Optional[int] = None + posterior_threshold: Optional[float] = None + posterior_alpha: Optional[float] = None + + # required configuration params passed from engine + target_model_config: ModelConfig = field(default=None, + init=True) # type: ignore + target_parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore + enable_chunked_prefill: bool = field(default=None, + init=True) # type: ignore + disable_log_stats: bool = field(default=None, init=True) # type: ignore + + # params generated in the post-init stage + draft_model_config: ModelConfig = field(default=None, + init=True) # type: ignore + draft_parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # spec decode does not use `torch.compile` yet. 
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + @classmethod + def from_dict(cls, dict_value: dict) -> "SpeculativeConfig": + """Parse the CLI value for the speculative config.""" + return cls(**dict_value) + + @staticmethod + def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + return hf_config + + def __post_init__(self): + + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + # mtp acceleration for more models besides deepseek_v3 + if self.target_model_config.hf_text_config.model_type \ + == "deepseek_v3": + # use the draft model from the same model: + self.model = self.target_model_config.model + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError("num_speculative_tokens was provided without " + "speculative model.") + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if (self.prompt_lookup_min is None + and self.prompt_lookup_max is None): + # TODO(woosuk): Tune these values. They are arbitrarily chosen. + self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. 
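+ # The ngram proposer has no separate draft network, so the target model
+ # and parallel configs are reused as-is; only metadata such as
+ # vocab_size is read from them.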
+ self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + self.draft_model_config = ModelConfig( + model=self.model, + task="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config. + trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config. + tokenizer_revision, + max_model_len=None, + spec_target_max_model_len=self.target_model_config. + max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_seq_len_to_capture=self.target_model_config. + max_seq_len_to_capture, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + else: + self.method = "draft_model" + + # Replace hf_config for EAGLE draft_model + if self.method == "eagle": + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") + + from vllm.transformers_utils.configs.eagle import ( + EAGLEConfig) + if isinstance(self.draft_model_config.hf_config, + EAGLEConfig): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config) + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, + "n_predict", None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. 
+ raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) + self.draft_pipeline_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_pp( + self.target_parallel_config, + self.draft_pipeline_parallel_size, + self.draft_model_config.hf_config + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_pipeline_parallel_size)) + + if self.acceptance_method == "typical_acceptance_sampler": + if self.posterior_threshold is None: + self.posterior_threshold = 0.09 + if self.posterior_alpha is None: + self.posterior_alpha = 0.3 + + self._verify_args() + + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + + @staticmethod + def _verify_and_get_draft_tp( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_tensor_parallel_size. + """ + # If speculative_draft_tensor_parallel_size is unset then set it + # appropriately else verify that it is set correctly. 
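+ # Illustrative behaviour: with a target tensor_parallel_size of 4, an
+ # mlp_speculator draft is forced to TP=1 (with a warning), any other
+ # draft defaults to TP=4, and an explicit value other than 1 or 4 is
+ # rejected below.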
+ if speculative_draft_tensor_parallel_size is None: + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "%s cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1", + draft_hf_config.model_type) + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size + elif speculative_draft_tensor_parallel_size not in ( + 1, target_parallel_config.tensor_parallel_size): + raise ValueError( + f"{speculative_draft_tensor_parallel_size=} cannot be " + f"other value than 1 or target model tensor_parallel_size") + return speculative_draft_tensor_parallel_size + + @staticmethod + def _verify_and_get_draft_pp( + target_parallel_config: ParallelConfig, + speculative_draft_pipeline_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_pipeline_parallel_size. + """ + # If speculative_draft_pipeline_parallel_size is unset then set it + # appropriately else verify that it is set correctly. + if speculative_draft_pipeline_parallel_size is None: + speculative_draft_pipeline_parallel_size = \ + target_parallel_config.pipeline_parallel_size + elif speculative_draft_pipeline_parallel_size not in ( + 1, target_parallel_config.pipeline_parallel_size): + raise ValueError( + f"{speculative_draft_pipeline_parallel_size=} cannot be " + f"other value than 1 or target model pipeline_parallel_size") + return speculative_draft_pipeline_parallel_size + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, + speculative_draft_pipeline_parallel_size: int, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. + """ + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=speculative_draft_pipeline_parallel_size, + tensor_parallel_size=speculative_draft_tensor_parallel_size, + num_virtual_engine=target_parallel_config.num_virtual_engine, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config. + max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config. + disable_custom_all_reduce, + tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, + ray_workers_use_nsight=target_parallel_config. + ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + def _verify_args(self) -> None: + if self.num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative model unless the draft model config contains an " + "n_predict parameter.") + + if self.num_speculative_tokens <= 0: + raise ValueError("Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens}).") + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config) + # Validate and set draft token acceptance related settings. + + if self.acceptance_method is None: + raise ValueError("acceptance_method is not set. 
" + "Expected values are rejection_sampler or " + "typical_acceptance_sampler.") + + if (self.acceptance_method != 'rejection_sampler' + and self.acceptance_method != 'typical_acceptance_sampler'): + raise ValueError( + "Expected acceptance_method to be either " + "rejection_sampler or typical_acceptance_sampler. Instead it " + f"is {self.acceptance_method}") + + if self.acceptance_method == "typical_acceptance_sampler" and ( + (self.posterior_threshold is not None + and self.posterior_threshold < 0) or + (self.posterior_alpha is not None and self.posterior_alpha < 0)): + raise ValueError( + "Expected the posterior_threshold and posterior_alpha of " + "typical_acceptance_sampler to be > 0. " + "Instead found posterior_threshold = " + f"{self.posterior_threshold} and posterior_alpha = " + f"{self.posterior_alpha}") + + if (self.disable_by_batch_size is not None + and self.disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}") + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. + """ + return self.num_speculative_tokens + + def __repr__(self) -> str: + method = self.method + model = None if method == "ngram" else self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" + + +@dataclass +class LoRAConfig: + max_lora_rank: int + max_loras: int + fully_sharded_loras: bool = False + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[Union[torch.dtype, str]] = None + lora_extra_vocab_size: int = 256 + # This is a constant. + lora_vocab_padding_size: ClassVar[int] = 256 + long_lora_scaling_factors: Optional[tuple[float]] = None + bias_enabled: bool = False + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.max_lora_rank) + factors.append(self.max_loras) + factors.append(self.fully_sharded_loras) + factors.append(self.lora_dtype) + factors.append(self.lora_extra_vocab_size) + factors.append(self.long_lora_scaling_factors) + factors.append(self.bias_enabled) + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + # Setting the maximum rank to 512 should be able to satisfy the vast + # majority of applications. 
+ possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) + possible_lora_extra_vocab_size = (256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})") + + def verify_with_cache_config(self, cache_config: CacheConfig): + if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: + raise ValueError( + "V0 LoRA does not support CPU offload, please use V1.") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + if scheduler_config.chunked_prefill_enabled: + logger.warning("LoRA with chunked prefill is still experimental " + "and may be unstable.") + + +@dataclass +class PromptAdapterConfig: + max_prompt_adapters: int + max_prompt_adapter_token: int + max_cpu_prompt_adapters: Optional[int] = None + prompt_adapter_dtype: Optional[torch.dtype] = None + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_prompt_adapter_token == 0: + raise ValueError("max_prompt_adapter_token must be set.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype in (None, "auto"): + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + + +@dataclass +class MultiModalConfig: + """Controls the behavior of multimodal models.""" + + limit_per_prompt: Mapping[str, int] = field(default_factory=dict) + """ + The maximum number of input items allowed per prompt for each modality. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. 
+ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def get_limit_per_prompt(self, modality: str) -> int: + """ + Get the maximum number of input items allowed per prompt + for the given modality. + + If not set by the user, this defaults to `1`. + """ + return self.limit_per_prompt.get(modality, 1) + + # TODO: Add configs to init vision tower or not. + + +@dataclass +class PoolerConfig: + """Controls the behavior of output pooling in pooling models.""" + + pooling_type: Optional[str] = None + """ + The pooling method of the pooling model. This should be a key in + :class:`vllm.model_executor.layers.pooler.PoolingType`. + """ + + normalize: Optional[bool] = None + """ + Whether to normalize the pooled outputs. Usually, this should be set to + ``True`` for embedding outputs. + """ + + softmax: Optional[bool] = None + """ + Whether to apply softmax to the pooled outputs. Usually, this should be set + to ``True`` for classification outputs. + """ + + step_tag_id: Optional[int] = None + """ + If set, only the score corresponding to the ``step_tag_id`` in the + generated sentence should be returned. Otherwise, the scores for all tokens + are returned. + """ + + returned_token_ids: Optional[list[int]] = None + """ + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the + ``math-shepherd-mistral-7b-prm`` model. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + @staticmethod + def from_json(json_str: str) -> "PoolerConfig": + return PoolerConfig(**json.loads(json_str)) + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] # + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. 
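+ # Resolution example: dtype="auto" with a float32 checkpoint resolves to
+ # torch.float16 on most platforms (CPU/HPU special cases are handled
+ # below), while an explicit string must be a key of
+ # _STR_DTYPE_TO_TORCH_DTYPE.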
+ config_dtype = getattr(config, "torch_dtype", None) + + # Fallbacks for multi-modal models if the root config + # does not define torch_dtype + if config_dtype is None and hasattr(config, "text_config"): + config_dtype = getattr(config.text_config, "torch_dtype", None) + if config_dtype is None and hasattr(config, "vision_config"): + config_dtype = getattr(config.vision_config, "torch_dtype", None) + + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + # Following common practice, we use float16 for float32 models + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + + from vllm.platforms import current_platform + if (current_platform.is_cpu() + and current_platform.get_cpu_architecture() + == CpuArchEnum.POWERPC + and (config_dtype == torch.float16 + or config_dtype == torch.float32)): + logger.info( + "For POWERPC, we cast models to bfloat16 instead of " + "using float16 by default. Float16 is not currently " + "supported for POWERPC.") + torch_dtype = torch.bfloat16 + + # TODO: change this condition to check if the platform support bf16 + # instead of checking the OS. For instance M2 shall supports bf16 + # already. But we need to modify `cpu_extension.cmake` to activate + # the feature in the build. + if (current_platform.is_cpu() and sys.platform.startswith("darwin") + and current_platform.get_cpu_architecture() + == CpuArchEnum.ARM and config_dtype == torch.bfloat16): + logger.info("For macOS with Apple Silicon, currently bfloat16 " + "is not supported. Setting dtype to float16.") + torch_dtype = torch.float16 + + if current_platform.is_hpu() and config_dtype == torch.float16: + logger.info( + "For HPU, we cast models to bfloat16 instead of " + "using float16 by default. Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[Union[int, list[Optional[int]]]], + spec_target_max_model_len: Optional[int] = None, + encoder_config: Optional[Any] = None, +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Whisper + "max_target_positions", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys. 
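+ # Illustrative: a config.json with max_position_embeddings=32768 and
+ # model_max_length=8192 yields derived_max_model_len=8192 here (the
+ # smaller value wins, and model_max_length is preferred when present).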
+ max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + # For Command-R / Cohere, Cohere2 / Aya Vision models + if tmp_max_len := getattr(hf_config, "model_max_length", None): + max_len_key = "model_max_length" + derived_max_model_len = tmp_max_len + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + + sliding_window_len_min = get_min_sliding_window(sliding_window_len) + max_len_key = "sliding_window" \ + if sliding_window_len_min < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, + sliding_window_len_min) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + if spec_target_max_model_len is not None: + # If this is a speculative draft model, we use the max model len + # from the target model. + return spec_target_max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE + # scaling, so we skip applying the scaling factor again. + if rope_scaling is not None and "gemma3" not in hf_config.model_type: + # No need to consider "type" key because of patch_rope_scaling when + # loading HF config + rope_type = rope_scaling["rope_type"] + + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") + + # NOTE: rope_type == "default" does not define factor + # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py + scaling_factor = rope_scaling.get("factor", 1.0) + + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + if encoder_config and "max_seq_length" in encoder_config: + derived_max_model_len = encoder_config["max_seq_length"] + + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. + if max_model_len is None: + max_model_len = int(derived_max_model_len) + elif max_model_len > derived_max_model_len: + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. 
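+ # Illustrative: if the derived limit is 4096 but config.json also carries
+ # model_max_length=131072, a user-specified max_model_len of 16384 is
+ # accepted here (provided sliding window is not disabled); anything above
+ # model_max_length still requires VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.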
+ model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") + else: + msg = ( + f"User-specified max_model_len ({max_model_len}) is greater " + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json). This may lead " + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + return int(max_model_len) + + +def get_min_sliding_window( + sliding_window: Union[int, list[Optional[int]]]) -> int: + if isinstance(sliding_window, list): + return min(s for s in sliding_window if s is not None) + + return sliding_window + + +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, list[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine""" + + # Which guided decoding algo to use. + # 'outlines' / 'lm-format-enforcer' / 'xgrammar' + guided_decoding_backend: str = 'xgrammar' + + reasoning_backend: Optional[str] = None + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + v0_valid_guided_backends = [ + 'outlines', 'lm-format-enforcer', 'xgrammar' + ] + v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto'] + + backend = GuidedDecodingParams( + backend=self.guided_decoding_backend).backend_name + if envs.VLLM_USE_V1: + valid_guided_backends = v1_valid_guided_backends + else: + valid_guided_backends = v0_valid_guided_backends + if backend not in valid_guided_backends: + raise ValueError(f"Invalid guided_decoding_backend '{backend}'," + f" must be one of {valid_guided_backends}") + + +@dataclass +class ObservabilityConfig: + """Configuration for observability - metrics and tracing.""" + show_hidden_metrics: bool = False + + otlp_traces_endpoint: Optional[str] = None + + # Collecting detailed timing information for each request can be expensive. + + # If set, collects the model forward time for the request. + collect_model_forward_time: bool = False + + # If set, collects the model execute time for the request. + collect_model_execute_time: bool = False + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if not is_otel_available() and self.otlp_traces_endpoint is not None: + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}") + + +class KVTransferConfig(BaseModel): + """Configuration for distributed KV cache transfer.""" + + # The KV connector for vLLM to transmit KV caches between vLLM instances. + kv_connector: Optional[str] = None + + # The device used by kv connector to buffer the KV cache. + # Currently only support 'cuda'. + kv_buffer_device: Optional[str] = "cuda" + + # The buffer size for TorchDistributedConnector. Measured in number of + # bytes. Recommended value: 1e9 (about 1GB). + kv_buffer_size: float = 1e9 + + # Whether this vLLM instance produces, consumes KV cache, or both. Choices + # are 'kv_producer', 'kv_consumer', and 'both'. + kv_role: Optional[str] = None + + # The rank of this vLLM instance in the KV cache transfer. Typical value: + # 0 for prefill instance, 1 for decode instance. + # Currently only 1P1D is supported. + kv_rank: Optional[int] = None + + # The number of parallel instances for KV cache transfer. For + # PyNcclConnector, this should be 2. 
+ kv_parallel_size: int = 1 + + # The KV connector ip, used to build distributed connection + kv_ip: str = "127.0.0.1" + + # The KV connector port, used to build distributed connection + kv_port: int = 14579 + + # any extra config that the connector may need + kv_connector_extra_config: dict[str, Any] = {} + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + @classmethod + def from_cli(cls, cli_value: str) -> "KVTransferConfig": + """Parse the CLI value for the kv cache transfer config.""" + return KVTransferConfig.model_validate_json(cli_value) + + def model_post_init(self, __context: Any) -> None: + + if self.kv_role is not None and self.kv_role not in [ + "kv_producer", "kv_consumer", "kv_both" + ]: + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + "is set, supported roles are `kv_producer`, " + "`kv_consumer`, and `kv_both`") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_both"] + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_consumer", "kv_both"] + + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) + + +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +class CompilationConfig(BaseModel): + """ + Configuration for compilation. + It has three parts: + - Top-level Compilation control: + - level: the level of compilation. + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation. + - debug_dump_path: the path to dump the debug information. + - cache_dir: the directory to store the compiled graph, to + accelerate Inductor compilation. By default, it will use + model-related information to generate a cache directory. + - backend: the backend for compilation. It needs to be a string. + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the backend function. + We use string to avoid serialization issues when using compilation in a distributed setting. + When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). + When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). 
+ - custom_ops: fine-grained control over which custom ops to enable/disable. + Use 'all' to enable all, 'none' to disable all. + Also specify a list of custom op names to enable (prefixed with a '+'), + or disable (prefixed with a '-'). + Examples: + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + By default, all custom ops are enabled when running without Inductor + and disabled when running with Inductor (compile_level >= Inductor). + - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation. + - CudaGraph capture: + - use_cudagraph: whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future. + - cudagraph_capture_sizes: sizes to capture cudagraph. + - None (default): capture sizes are inferred from vllm config. + - list[int]: capture sizes are specified as given. + - cudagraph_num_of_warmups: number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. + - Inductor compilation: + - use_inductor: whether to use inductor compilation. + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for compile_sizes, + using configurations in inductor_compile_config. + - compile_sizes: sizes to compile for inductor. In addition + to integers, it also supports "cudagraph_capture_sizes" to + specify the sizes for cudagraph capture. + - inductor_compile_config: additional configurations for inductor. + - None: use default configurations. + - inductor_passes: additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses json format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - custom inductor passes: see PassConfig for more details + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. 
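+
+    A minimal usage sketch (illustrative values, not a recommendation):
+    the config can be built from a CLI string via `from_cli`, either as a
+    bare level such as `CompilationConfig.from_cli("3")`, or as a dict
+    literal such as
+    `CompilationConfig.from_cli('{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}')`,
+    which selects piecewise compilation and pins the cudagraph capture sizes.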
+ """ # noqa + level: int = 0 + debug_dump_path: str = "" + cache_dir: str = "" + backend: str = "" + custom_ops: list[str] = Field(default_factory=list) + splitting_ops: list[str] = Field(default=None) # type: ignore + + use_inductor: bool = True + compile_sizes: Optional[list[Union[int, str]]] = Field(default=None) + inductor_compile_config: dict = Field(default_factory=dict) + inductor_passes: dict[str, str] = Field(default_factory=dict) + + use_cudagraph: bool = False + cudagraph_num_of_warmups: int = 0 + cudagraph_capture_sizes: Optional[list[int]] = None + cudagraph_copy_inputs: bool = False + + class PassConfig(BaseModel): + """ + Configuration for custom Inductor passes. + This is separate from general CompilationConfig so that inductor passes + don't all have access to full configuration - that would create a cycle + as the PassManager is set as a property of config. + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graphs. Default is . + - enable_fusion: whether to enable the custom fusion pass. + - enable_noop: whether to enable the custom no-op elimination pass. + TODO(luka) better pass enabling system. + """ + dump_graph_stages: list[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + enable_noop: bool = True + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) + return InductorPass.hash_dict(dict_) + + def model_post_init(self, __context: Any) -> None: + if not self.enable_noop and self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm + quant (fp8) fusion might not work") + + pass_config: PassConfig = Field(default_factory=PassConfig) + + # not configurable, computed after init + max_capture_size: int = PrivateAttr + local_cache_dir: str = PrivateAttr # local cache dir for each rank + # optimization: + # Intuitively, bs_to_padded_graph_size should be dict[int, int]. + # since we know all keys are in a range [0, max_capture_size], + # we can optimize it to list[int] for better lookup performance. + bs_to_padded_graph_size: list[int] = PrivateAttr + + # keep track of enabled and disabled custom ops + enabled_custom_ops: Counter[str] = PrivateAttr + disabled_custom_ops: Counter[str] = PrivateAttr + traced_files: set[str] = PrivateAttr + compilation_time: float = PrivateAttr + + # Per-model forward context + # Map from layer name to the attention cls + static_forward_context: dict[str, Any] = PrivateAttr + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + factors.append(self.level) + factors.append(self.backend) + factors.append(self.custom_ops) + factors.append(self.splitting_ops) + factors.append(self.use_inductor) + factors.append(self.inductor_compile_config) + factors.append(self.inductor_passes) + factors.append(self.pass_config.uuid()) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __repr__(self) -> str: + exclude = { + "static_forward_context", + "enabled_custom_ops", + "disabled_custom_ops", + "compilation_time", + "bs_to_padded_graph_size", + "pass_config", + "traced_files", + } + return self.model_dump_json(exclude=exclude, exclude_unset=True) + + __str__ = __repr__ + + @classmethod + def from_cli(cls, cli_value: str) -> "CompilationConfig": + """Parse the CLI value for the compilation config.""" + if cli_value in ["0", "1", "2", "3"]: + return cls(level=int(cli_value)) + # do not use `eval`, it is dangerous and can execute arbitrary code + dict_value = ast.literal_eval(cli_value) + return CompilationConfig.model_validate(dict_value) + + def model_post_init(self, __context: Any) -> None: + + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2: + # 1. A bug in PyTorch, fixed in 2.7: + # https://github.com/pytorch/pytorch/issues/147924 + # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't + # work with V2. Addressing this will take extra engineering effort + # and it is not yet a priority. RFC here: + # https://github.com/vllm-project/vllm/issues/14703 + + if Version(importlib.metadata.version('torch')) >= Version("2.6"): + KEY = 'enable_auto_functionalized_v2' + if KEY not in self.inductor_compile_config: + self.inductor_compile_config[KEY] = False + + if self.splitting_ops is None: + self.splitting_ops = [] + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) + + self.enabled_custom_ops = Counter() + self.disabled_custom_ops = Counter() + self.traced_files = set() + self.static_forward_context = {} + self.compilation_time = 0.0 + + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: + if self.level == CompilationLevel.NO_COMPILATION: + raise ValueError("No compilation level is set.") + + from torch._dynamo.backends.registry import list_backends + torch_backends = list_backends(exclude_tags=tuple()) + if self.level in [ + CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE + ]: + if self.backend == "": + return "eager" + if self.backend in torch_backends: + return self.backend + return resolve_obj_by_qualname(self.backend) + + # TODO: pass user-specified backend to piecewise compilation + # merge with the config use_inductor + assert self.level == CompilationLevel.PIECEWISE + + from vllm.compilation.backends import VllmBackend + return VllmBackend(vllm_config) + + def init_with_cudagraph_sizes(self, + cudagraph_capture_sizes: list[int]) -> None: + """To complete the 
initialization of config, + we need to know the cudagraph sizes.""" + + if self.cudagraph_capture_sizes is None: + self.cudagraph_capture_sizes = cudagraph_capture_sizes + else: + # de-duplicate the sizes provided by the config + self.cudagraph_capture_sizes = list( + set(self.cudagraph_capture_sizes)) + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + cudagraph_capture_sizes, self.cudagraph_capture_sizes) + + computed_compile_sizes = [] + if self.compile_sizes is not None: + # de-duplicate the sizes provided by the config + self.compile_sizes = list(set(self.compile_sizes)) + for x in self.compile_sizes: + if isinstance(x, str): + assert x == "cudagraph_capture_sizes", \ + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" + computed_compile_sizes.extend(self.cudagraph_capture_sizes) + else: + assert isinstance(x, int) + computed_compile_sizes.append(x) + self.compile_sizes = computed_compile_sizes # type: ignore + + # sort to make sure cudagraph capture sizes are in descending order + self.cudagraph_capture_sizes.sort(reverse=True) + self.max_capture_size = self.cudagraph_capture_sizes[ + 0] if self.cudagraph_capture_sizes else 0 + + # pre-compute the mapping from batch size to padded graph size + self.bs_to_padded_graph_size = [ + 0 for i in range(self.max_capture_size + 1) + ] + for end, start in zip(self.cudagraph_capture_sizes, + self.cudagraph_capture_sizes[1:] + [0]): + for bs in range(start, end): + if bs == start: + self.bs_to_padded_graph_size[bs] = start + else: + self.bs_to_padded_graph_size[bs] = end + self.bs_to_padded_graph_size[ + self.max_capture_size] = self.max_capture_size + + def set_splitting_ops_for_v1(self): + # If default, override splitting ops for piecewise cudagraph on V1. + # NOTE: this function needs to be called + if not self.splitting_ops: + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + + +@dataclass +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This + simplifies passing around the distinct configurations in the codebase. + """ + + model_config: ModelConfig = field(default=None, init=True) # type: ignore + cache_config: CacheConfig = field(default=None, init=True) # type: ignore + parallel_config: ParallelConfig = field(default_factory=ParallelConfig, + init=True) + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, + init=True) + device_config: DeviceConfig = field(default=None, + init=True) # type: ignore + load_config: LoadConfig = field(default=None, init=True) # type: ignore + lora_config: Optional[LoRAConfig] = None + speculative_config: SpeculativeConfig = field(default=None, + init=True) # type: ignore + decoding_config: Optional[DecodingConfig] = None + observability_config: Optional[ObservabilityConfig] = None + prompt_adapter_config: Optional[PromptAdapterConfig] = None + quant_config: Optional[QuantizationConfig] = None + compilation_config: CompilationConfig = field(default=None, + init=True) # type: ignore + kv_transfer_config: KVTransferConfig = field(default=None, + init=True) # type: ignore + # some opaque config, only used to provide additional information + # for the hash computation, mainly used for testing, debugging or out of + # tree config registration. 
+ additional_config: SupportsHash = field(default=None, + init=True) # type: ignore + instance_id: str = "" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + + # summarize vllm config + vllm_factors: list[Any] = [] + from vllm import __version__ + vllm_factors.append(__version__) + vllm_factors.append(envs.VLLM_USE_V1) + if self.model_config: + vllm_factors.append(self.model_config.compute_hash()) + else: + vllm_factors.append("None") + if self.cache_config: + vllm_factors.append(self.cache_config.compute_hash()) + else: + vllm_factors.append("None") + if self.parallel_config: + vllm_factors.append(self.parallel_config.compute_hash()) + else: + vllm_factors.append("None") + if self.scheduler_config: + vllm_factors.append(self.scheduler_config.compute_hash()) + else: + vllm_factors.append("None") + if self.device_config: + vllm_factors.append(self.device_config.compute_hash()) + else: + vllm_factors.append("None") + if self.load_config: + vllm_factors.append(self.load_config.compute_hash()) + else: + vllm_factors.append("None") + if self.lora_config: + vllm_factors.append(self.lora_config.compute_hash()) + # LoRA creates static buffers based on max_num_batched_tokens. + # The tensor sizes and strides get captured in the torch.compile + # graph explicitly. + vllm_factors.append( + str(self.scheduler_config.max_num_batched_tokens)) + else: + vllm_factors.append("None") + if self.speculative_config: + vllm_factors.append(self.speculative_config.compute_hash()) + else: + vllm_factors.append("None") + if self.decoding_config: + vllm_factors.append(self.decoding_config.compute_hash()) + else: + vllm_factors.append("None") + if self.observability_config: + vllm_factors.append(self.observability_config.compute_hash()) + else: + vllm_factors.append("None") + if self.prompt_adapter_config: + vllm_factors.append(self.prompt_adapter_config.compute_hash()) + else: + vllm_factors.append("None") + if self.quant_config: + pass # should be captured by model_config.quantization + if self.compilation_config: + vllm_factors.append(self.compilation_config.compute_hash()) + else: + vllm_factors.append("None") + if self.kv_transfer_config: + vllm_factors.append(self.kv_transfer_config.compute_hash()) + else: + vllm_factors.append("None") + if self.additional_config: + vllm_factors.append(self.additional_config.compute_hash()) + else: + vllm_factors.append("None") + factors.append(vllm_factors) + + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def pad_for_cudagraph(self, batch_size: int) -> int: + # if batch_size > self.compilation_config.max_capture_size, + # it should raise an IndexError. 
+ # the caller should make sure the batch_size is within the range, + # i.e., batch_size <= self.compilation_config.max_capture_size + return self.compilation_config.bs_to_padded_graph_size[batch_size] + + @staticmethod + def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + from vllm.platforms import current_platform + if model_config.quantization is not None: + from vllm.model_executor.model_loader.weight_utils import ( + get_quant_config) + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. Minimum " + f"capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + def with_hf_config( + self, + hf_config: PretrainedConfig, + architectures: Optional[list[str]] = None, + ) -> "VllmConfig": + if architectures is not None: + hf_config = copy.deepcopy(hf_config) + hf_config.architectures = architectures + + model_config = copy.deepcopy(self.model_config) + model_config.hf_config = hf_config + + return replace(self, model_config=model_config) + + def __post_init__(self): + """Verify configs are valid & consistent with each other. + """ + if self.model_config is not None: + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) + self.model_config.verify_with_parallel_config(self.parallel_config) + + if self.cache_config is not None: + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config: + self.lora_config.verify_with_cache_config(self.cache_config) + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + if self.quant_config is None and \ + self.model_config is not None and self.load_config is not None: + self.quant_config = VllmConfig._get_quantization_config( + self.model_config, self.load_config) + + from vllm.platforms import current_platform + if self.scheduler_config is not None and \ + self.model_config is not None and \ + self.scheduler_config.chunked_prefill_enabled and \ + self.model_config.dtype == torch.float32 and \ + current_platform.get_device_capability() == (7, 5): + logger.warning_once( + "Turing devices tensor cores do not support float32 matmul. " + "To workaround this limitation, vLLM will set 'ieee' input " + "precision for chunked prefill triton kernels.") + + if self.compilation_config is None: + self.compilation_config = CompilationConfig() + if envs.VLLM_USE_V1 and self.model_config is not None and \ + not self.model_config.enforce_eager: + # NOTE(woosuk): Currently, we use inductor because the piecewise + # CUDA graphs do not work properly with the custom CUDA kernels. 
+ # FIXME(woosuk): Disable inductor to reduce the compilation time + # and avoid any potential issues with the inductor. + # FIXME(rob): Add function to set all of these. + self.compilation_config.custom_ops = ["none"] + self.compilation_config.use_cudagraph = True + self.compilation_config.use_inductor = True + self.compilation_config.cudagraph_num_of_warmups = 1 + self.compilation_config.pass_config.enable_fusion = False + self.compilation_config.pass_config.enable_noop = False + self.compilation_config.level = CompilationLevel.PIECEWISE + self.compilation_config.set_splitting_ops_for_v1() + + self._set_cudagraph_sizes() + + if self.cache_config is not None and \ + self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION \ + and not envs.VLLM_USE_V1: + logger.warning( + "CPU offload is not supported with `torch.compile` in v0 yet." + " Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if ((not envs.VLLM_USE_V1) and self.lora_config is not None + and self.compilation_config.level + != CompilationLevel.NO_COMPILATION): + logger.warning( + "LoRA for V0 is not supported with `torch.compile` yet. " + "Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + + if self.model_config and self.model_config.use_mla and \ + not (current_platform.is_cuda() or current_platform.is_rocm()): + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + self.scheduler_config.enable_chunked_prefill = False + self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.max_num_batched_tokens = max( + self.scheduler_config.max_model_len, + _DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.cache_config is not None: + self.cache_config.enable_prefix_caching = False + + current_platform.check_and_update_config(self) + + if not self.instance_id: + self.instance_id = random_uuid()[:5] + + def _set_cudagraph_sizes(self): + """ + cudagraph batchsize padding logic: + + `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible + batch sizes that cudagraph will capture. + + Depending on the engine's configuration of `max_num_seqs`, the + candidate batch sizes to capture cudagraph will shrink to the subset + which just cover the range of `[1, max_num_seqs]`. In the common case, + `max_num_seqs` is 256, and the cudagraph batch sizes will be + `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. + + However, if users specify the cudagraph capture sizes through + compilation config, we will use the specified sizes instead. + + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in descending order). + + During runtime, if batchsize is larger than + `vllm_config.compilation_config.cudagraph_capture_sizes`, + no cudagraph will be used. + If the batch size is no larger than + `vllm_config.compilation_config.cudagraph_capture_sizes`, + we can quickly find the padded graph size for a given batch size by + looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. 
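+
+        A worked example with illustrative numbers: for the V0 path with
+        `max_num_seqs=64` and `enforce_eager=False`, the candidate list
+        shrinks to `[1, 2, 4, 8, 16, ..., 64]`. At runtime a batch size of
+        13 is padded up to 16 via `bs_to_padded_graph_size[13] == 16`,
+        while a batch size above 64 (the `max_capture_size`) runs without
+        cudagraph.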
+ """ + + # calculate the default `batch_size_capture_list` + if not envs.VLLM_USE_V1: + batch_size_capture_list = [] + max_batchsize_to_capture = 0 + if self.scheduler_config is not None and \ + self.model_config is not None and \ + not self.model_config.enforce_eager: + + possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + # find the minimum size that is larger than max_num_seqs, + # which then becomes the max_batchsize_to_capture + larger_sizes = [ + x for x in possible_sizes + if x >= self.scheduler_config.max_num_seqs + ] + if larger_sizes: + max_batchsize_to_capture = larger_sizes[0] + else: + max_batchsize_to_capture = possible_sizes[-1] + + # filter out the sizes that are + # larger than max_batchsize_to_capture + batch_size_capture_list = [ + size for size in possible_sizes + if size <= max_batchsize_to_capture + ] + else: + batch_size_capture_list = [] + if self.model_config is not None and \ + not self.model_config.enforce_eager: + batch_size_capture_list = [1, 2, 4 + ] + [i for i in range(8, 513, 8)] + max_num_tokens = self.scheduler_config.max_num_batched_tokens + batch_size_capture_list = [ + size for size in batch_size_capture_list + if size <= max_num_tokens + ] + + self.compilation_config.init_with_cudagraph_sizes( + batch_size_capture_list) + + def __str__(self): + return ( + f"model={self.model_config.model!r}," + f" speculative_config={self.speculative_config!r}," + f" tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}," + f" tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"override_neuron_config={self.model_config.override_neuron_config}," + f" tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}," + f" download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}," + f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f" device_config={self.device_config.device}, " + f"decoding_config={self.decoding_config!r}, " + f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, " + f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"use_async_output_proc={self.model_config.use_async_output_proc}, " + f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, " # noqa + f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, " + f"pooler_config={self.model_config.pooler_config!r}, " + f"compilation_config={self.compilation_config!r}") + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): + """ + Temporarily set the current vLLM 
config. + Used during model initialization. + We save the current vLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the vLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + yield + finally: + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) + if check_compile and \ + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. Please open an issue on GitHub" + " if you want it to be supported.", + vllm_config.model_config.model) + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current vLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() + return _current_vllm_config diff --git a/vllm/connections.py b/vllm/connections.py new file mode 100644 index 0000000..2c259bb --- /dev/null +++ b/vllm/connections.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Mapping, MutableMapping +from pathlib import Path +from typing import Optional +from urllib.parse import urlparse + +import aiohttp +import requests + +from vllm.version import __version__ as VLLM_VERSION + + +class HTTPConnection: + """Helper class to send HTTP requests.""" + + def __init__(self, *, reuse_client: bool = True) -> None: + super().__init__() + + self.reuse_client = reuse_client + + self._sync_client: Optional[requests.Session] = None + self._async_client: Optional[aiohttp.ClientSession] = None + + def get_sync_client(self) -> requests.Session: + if self._sync_client is None or not self.reuse_client: + self._sync_client = requests.Session() + + return self._sync_client + + # NOTE: We intentionally use an async function even though it is not + # required, so that the client is only accessible inside async event loop + async def get_async_client(self) -> aiohttp.ClientSession: + if self._async_client is None or not self.reuse_client: + self._async_client = aiohttp.ClientSession(trust_env=True) + + return self._async_client + + def _validate_http_url(self, url: str): + parsed_url = urlparse(url) + + if parsed_url.scheme not in ("http", "https"): + raise ValueError("Invalid HTTP URL: A valid HTTP URL " + "must have scheme 'http' or 'https'.") + + def _headers(self, **extras: str) -> MutableMapping[str, str]: + return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} + + def get_response( + self, + url: str, + *, + stream: bool = False, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = self.get_sync_client() + extra_headers = extra_headers 
or {} + + return client.get(url, + headers=self._headers(**extra_headers), + stream=stream, + timeout=timeout) + + async def get_async_response( + self, + url: str, + *, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = await self.get_async_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + timeout=timeout) + + def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.content + + async def async_get_bytes( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> bytes: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.read() + + def get_text(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.text + + async def async_get_text( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.text() + + def get_json(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.json() + + async def async_get_json( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.json() + + def download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + for chunk in r.iter_content(chunk_size): + f.write(chunk) + + return save_path + + async def async_download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + async for chunk in r.content.iter_chunked(chunk_size): + f.write(chunk) + + return save_path + + +global_http_connection = HTTPConnection() +"""The global :class:`HTTPConnection` instance used by vLLM.""" diff --git a/vllm/core/__init__.py b/vllm/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/core/block/__init__.py b/vllm/core/block/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py new file mode 100644 index 0000000..d4d31c5 --- /dev/null +++ b/vllm/core/block/block_table.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import List, Optional + +from vllm.core.block.common import BlockList +from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.utils import Device, cdiv, chunk_list + + +class BlockTable: + """A class to manage blocks for a specific sequence. + + The BlockTable maps a sequence of tokens to a list of blocks, where each + block represents a contiguous memory allocation for a portion of the + sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is + responsible for allocating and freeing memory for the blocks. 
+ + Args: + block_size (int): The maximum number of tokens that can be stored in a + single block. + block_allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]], optional): An optional list of existing + blocks to initialize the BlockTable with. If not provided, an empty + BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequence. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. + + Attributes: + _block_size (int): The maximum number of tokens that can be stored in a + single block. + _allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]]): The list of blocks managed by this + BlockTable. + _num_full_slots (int): The number of tokens currently stored in the + blocks. + """ + + def __init__( + self, + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, + ): + self._block_size = block_size + self._allocator = block_allocator + if _blocks is None: + _blocks = [] + self._blocks: BlockList = BlockList(_blocks) + + self._max_block_sliding_window = max_block_sliding_window + self._num_full_slots = self._get_num_token_ids() + + @staticmethod + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: + """Calculates the minimum number of blocks required to store a given + sequence of token IDs along with any look-ahead slots that may be + required (like in multi-step + chunked-prefill). + + This assumes worst-case scenario, where every block requires a new + allocation (e.g. ignoring prefix caching). + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + block_size (int): The maximum number of tokens that can be stored in + a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. + + Returns: + int: The minimum number of blocks required to store the given + sequence of token IDs along with any required look-ahead slots. + """ + return cdiv(len(token_ids) + num_lookahead_slots, block_size) + + def allocate(self, + token_ids: List[int], + device: Device = Device.GPU, + extra_hash: Optional[int] = None) -> None: + """Allocates memory blocks for storing the given sequence of token IDs. + + This method allocates the required number of blocks to store the given + sequence of token IDs. + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + device (Device, optional): The device on which the blocks should be + allocated. Defaults to Device.GPU. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefixcaching block. 
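+
+        Example (illustrative, assuming `allocator` is an already
+        constructed DeviceAwareBlockAllocator and the table was created
+        with block_size=16): calling `allocate(token_ids=list(range(40)))`
+        allocates three blocks (two full and one partially full) and sets
+        `num_full_slots` to 40.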
+ """ + assert not self._is_allocated + assert token_ids + blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_ids=token_ids, + device=device, + extra_hash=extra_hash) + self.update(blocks) + self._num_full_slots = len(token_ids) + + def update(self, blocks: List[Block]) -> None: + """Resets the table to the newly provided blocks + (with their corresponding block ids) + """ + self._blocks.update(blocks) + + def append_token_ids(self, + token_ids: List[int], + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None, + extra_hash: Optional[int] = None) -> None: + """Appends a sequence of token IDs to the existing blocks in the + BlockTable. + + This method appends the given sequence of token IDs to the existing + blocks in the BlockTable. If there is not enough space in the existing + blocks, new blocks are allocated using the `ensure_num_empty_slots` + method to accommodate the additional tokens. + + The token IDs are divided into chunks of size `block_size` (except for + the first chunk, which may be smaller), and each chunk is appended to a + separate block. + + Args: + token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + assert self._is_allocated, "no blocks have been allocated" + assert len(self._blocks) > 0 + + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots, + extra_hash=extra_hash) + + # Update the blocks with the new tokens + first_block_idx = self._num_full_slots // self._block_size + token_blocks = self._chunk_token_blocks_for_append(token_ids) + + for i, token_block in enumerate(token_blocks): + self._blocks.append_token_ids(first_block_idx + i, token_block) + + self._num_full_slots += len(token_ids) + + def ensure_num_empty_slots(self, + num_empty_slots: int, + extra_hash: Optional[int] = None) -> None: + """Ensures that the BlockTable has at least the specified number of + empty slots available. + + This method checks if the BlockTable has enough empty slots (i.e., + available space) to accommodate the requested number of tokens. If not, + it allocates additional blocks on the GPU to ensure that the required + number of empty slots is available. + + Args: + num_empty_slots (int): The minimum number of empty slots required. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + # Currently the block table only supports + # appending tokens to GPU blocks. 
+ device = Device.GPU + assert self._is_allocated + + if self._num_empty_slots >= num_empty_slots: + return + + slots_to_allocate = num_empty_slots - self._num_empty_slots + blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) + + for _ in range(blocks_to_allocate): + assert len(self._blocks) > 0 + self._blocks.append( + self._allocator.allocate_mutable_block( + prev_block=self._blocks[-1], + device=device, + extra_hash=extra_hash)) + + def fork(self) -> "BlockTable": + """Creates a new BlockTable instance with a copy of the blocks from the + current instance. + + This method creates a new BlockTable instance with the same block size, + block allocator, and a copy of the blocks from the current instance. The + new BlockTable has its own independent set of blocks, but shares the + same underlying memory allocation with the original BlockTable. + + Returns: + BlockTable: A new BlockTable instance with a copy of the blocks from + the current instance. + """ + assert self._is_allocated + assert len(self._blocks) > 0 + forked_blocks = self._allocator.fork(self._blocks[-1]) + return BlockTable( + block_size=self._block_size, + block_allocator=self._allocator, + _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, + ) + + def free(self) -> None: + """Frees the memory occupied by the blocks in the BlockTable. + + This method iterates over all the blocks in the `_blocks` list and calls + the `free` method of the `_allocator` object to release the memory + occupied by each block. After freeing all the blocks, the `_blocks` list + is set to `None`. + """ + for block in self.blocks: + self._allocator.free(block) + self._blocks.reset() + + @property + def physical_block_ids(self) -> List[int]: + """Returns a list of physical block indices for the blocks in the + BlockTable. + + This property returns a list of integers, where each integer represents + the physical block index of a corresponding block in the `_blocks` list. + The physical block index is a unique identifier for the memory location + occupied by the block. + + Returns: + List[int]: A list of physical block indices for the blocks in the + BlockTable. + """ + return self._blocks.ids() + + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + """Get the number of "unseen" tokens in the sequence. + + Unseen tokens are tokens in the sequence corresponding to this block + table, but are not yet appended to this block table. + + Args: + sequence_token_ids (List[int]): The list of token ids in the + sequence. + + Returns: + List[int]: The postfix of sequence_token_ids that has not yet been + appended to the block table. + """ + + # Since the block table is append-only, the unseen token ids are the + # ones after the appended ones. 
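+        # For example, if 20 tokens have been appended so far and the full
+        # sequence now holds 25 tokens, the trailing 5 token ids are returned.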
+ return sequence_token_ids[self.num_full_slots:] + + def _allocate_blocks_for_token_ids( + self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + blocks: List[Block] = [] + + block_token_ids = [] + tail_token_ids = [] + for cur_token_ids in chunk_list(token_ids, self._block_size): + if len(cur_token_ids) == self._block_size: + block_token_ids.append(cur_token_ids) + else: + tail_token_ids.append(cur_token_ids) + + if block_token_ids: + blocks.extend( + self._allocator.allocate_immutable_blocks( + prev_block, + block_token_ids=block_token_ids, + device=device, + extra_hash=extra_hash)) + prev_block = blocks[-1] + + if tail_token_ids: + assert len(tail_token_ids) == 1 + cur_token_ids = tail_token_ids[0] + + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device, extra_hash=extra_hash) + block.append_token_ids(cur_token_ids) + + blocks.append(block) + + return blocks + + def _get_all_token_ids(self) -> List[int]: + # NOTE: This function is O(seq_len); use sparingly. + token_ids: List[int] = [] + + if not self._is_allocated: + return token_ids + + for block in self.blocks: + token_ids.extend(block.token_ids) + + return token_ids + + def _get_num_token_ids(self) -> int: + res = 0 + for block in self.blocks: + res += len(block.token_ids) + + return res + + @property + def _is_allocated(self) -> bool: + return len(self._blocks) > 0 + + @property + def blocks(self) -> List[Block]: + return self._blocks.list() + + @property + def _num_empty_slots(self) -> int: + assert self._is_allocated + return len(self._blocks) * self._block_size - self._num_full_slots + + @property + def num_full_slots(self) -> int: + """Returns the total number of tokens currently stored in the + BlockTable. + + Returns: + int: The total number of tokens currently stored in the BlockTable. + """ + return self._num_full_slots + + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + """Determine how many blocks will be "touched" by appending the token + ids. + + This is required for the scheduler to determine whether a sequence can + continue generation, or if it must be preempted. + """ + # Math below is equivalent to: + # all_token_ids = token_ids + [-1] * num_lookahead_slots + # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) + # return len(token_blocks) + + num_token_ids = len(token_ids) + num_lookahead_slots + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + num_token_blocks = (1 + math.ceil( + (num_token_ids - first_chunk_size) / self._block_size)) + return num_token_blocks + + def _chunk_token_blocks_for_append( + self, token_ids: List[int]) -> List[List[int]]: + """Split the token ids into block-sized chunks so they can be easily + appended to blocks. The first such "token block" may have less token ids + than the block size, since the last allocated block may be partially + full. + + If no token ids are provided, then no chunks are returned. 
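+
+        Worked example (illustrative): with block_size=4 and 2 slots already
+        used in the last block, appending 7 token ids yields chunks of sizes
+        [2, 4, 1]; the first chunk tops up the partially filled block.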
+ """ + + if not token_ids: + return [] + + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + token_blocks.extend( + chunk_list(token_ids[first_chunk_size:], self._block_size)) + return token_blocks diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py new file mode 100644 index 0000000..1966eac --- /dev/null +++ b/vllm/core/block/common.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import deque +from dataclasses import dataclass +from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple + +from vllm.core.block.interfaces import Block, BlockAllocator + +BlockId = int +RefCount = int + + +class RefCounterProtocol(Protocol): + + def incr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def decr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def get(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + +class RefCounter(RefCounterProtocol): + """A class for managing reference counts for a set of block indices. + + The RefCounter class maintains a dictionary that maps block indices to their + corresponding reference counts. It provides methods to increment, decrement, + and retrieve the reference count for a given block index. + + Args: + all_block_indices (Iterable[BlockId]): An iterable of block indices + to initialize the reference counter with. + """ + + def __init__(self, all_block_indices: Iterable[BlockId]): + deduped = set(all_block_indices) + self._refcounts: Dict[BlockId, RefCount] = { + index: 0 + for index in deduped + } + + def incr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + pre_incr_refcount = self._refcounts[block_id] + + assert pre_incr_refcount >= 0 + + post_incr_refcount = pre_incr_refcount + 1 + self._refcounts[block_id] = post_incr_refcount + return post_incr_refcount + + def decr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + refcount = self._refcounts[block_id] + + assert refcount > 0 + refcount -= 1 + + self._refcounts[block_id] = refcount + + return refcount + + def get(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + return self._refcounts[block_id] + + def as_readonly(self) -> "ReadOnlyRefCounter": + return ReadOnlyRefCounter(self) + + +class ReadOnlyRefCounter(RefCounterProtocol): + """A read-only view of the RefCounter class. + + The ReadOnlyRefCounter class provides a read-only interface to access the + reference counts maintained by a RefCounter instance. It does not allow + modifications to the reference counts. + + Args: + refcounter (RefCounter): The RefCounter instance to create a read-only + view for. + """ + + def __init__(self, refcounter: RefCounter): + self._refcounter = refcounter + + def incr(self, block_id: BlockId) -> RefCount: + raise ValueError("Incr not allowed") + + def decr(self, block_id: BlockId) -> RefCount: + raise ValueError("Decr not allowed") + + def get(self, block_id: BlockId) -> RefCount: + return self._refcounter.get(block_id) + + +class CopyOnWriteTracker: + """A class for tracking and managing copy-on-write operations for blocks. + + The CopyOnWriteTracker class maintains a mapping of source block indices to + their corresponding copy-on-write destination block indices. It works in + conjunction with a RefCounter. + + Args: + refcounter (RefCounter): The reference counter used to track block + reference counts. 
+ """ + + def __init__(self, refcounter: RefCounterProtocol): + self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] + self._refcounter = refcounter + + def is_appendable(self, block: Block) -> bool: + """Checks if the block is shared or not. If shared, then it cannot + be appended and needs to be duplicated via copy-on-write + """ + block_id = block.block_id + if block_id is None: + return True + + refcount = self._refcounter.get(block_id) + return refcount <= 1 + + def record_cow(self, src_block_id: Optional[BlockId], + trg_block_id: Optional[BlockId]) -> None: + """Records a copy-on-write operation from source to target block id + Args: + src_block_id (BlockId): The source block id from which to copy + the data + trg_block_id (BlockId): The target block id to which the data + is copied + """ + assert src_block_id is not None + assert trg_block_id is not None + self._copy_on_writes.append((src_block_id, trg_block_id)) + + def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: + """Clears the copy-on-write tracking information and returns the current + state. + + This method returns a list mapping source block indices to + destination block indices for the current copy-on-write operations. + It then clears the internal tracking information. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices for the + current copy-on-write operations. + """ + cows = self._copy_on_writes + self._copy_on_writes = [] + return cows + + +class BlockPool: + """Used to pre-allocate block objects, in order to avoid excessive python + object allocations/deallocations. + The pool starts from "pool_size" objects and will increase to more objects + if necessary + + Note that multiple block objects may point to the same physical block id, + which is why this pool is needed, so that it will be easier to support + prefix caching and more complicated sharing of physical blocks. 
+ """ + + def __init__(self, block_size: int, create_block: Block.Factory, + allocator: BlockAllocator, pool_size: int): + self._block_size = block_size + self._create_block = create_block + self._allocator = allocator + self._pool_size = pool_size + assert self._pool_size >= 0 + + self._free_ids: Deque[int] = deque(range(self._pool_size)) + self._pool = [] + for i in range(self._pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None, + extra_hash=None)) + + def increase_pool(self): + """Doubles the internal pool size + """ + cur_pool_size = self._pool_size + new_pool_size = cur_pool_size * 2 + self._pool_size = new_pool_size + + self._free_ids += deque(range(cur_pool_size, new_pool_size)) + + for i in range(cur_pool_size, new_pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None, + extra_hash=None)) + + def init_block(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + physical_block_id: Optional[int], + extra_hash: Optional[int] = None) -> Block: + if len(self._free_ids) == 0: + self.increase_pool() + assert len(self._free_ids) > 0 + + pool_id = self._free_ids.popleft() + + block = self._pool[pool_id] + block.__init__( # type: ignore[misc] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + allocator=block._allocator, # type: ignore[attr-defined] + block_id=physical_block_id, + extra_hash=extra_hash) + block.pool_id = pool_id # type: ignore[attr-defined] + return block + + def free_block(self, block: Block) -> None: + self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined] + + +class BlockList: + """This class is an optimization to allow fast-access to physical + block ids. 
It maintains a block id list that is updated with the + block list and this avoids the need to reconstruct the block id + list on every iteration of the block manager + """ + + def __init__(self, blocks: List[Block]): + self._blocks: List[Block] = [] + self._block_ids: List[int] = [] + + self.update(blocks) + + def _add_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_ids.append(block_id) + + def _update_block_id(self, block_index: int, + new_block_id: Optional[BlockId]) -> None: + assert new_block_id is not None + self._block_ids[block_index] = new_block_id + + def update(self, blocks: List[Block]): + self._blocks = blocks + + # Cache block ids for fast query + self._block_ids = [] + for block in self._blocks: + self._add_block_id(block.block_id) + + def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: + block = self._blocks[block_index] + prev_block_id = block.block_id + + block.append_token_ids(token_ids) + + # CoW or promotion may update the internal block_id + if prev_block_id != block.block_id: + self._update_block_id(block_index, block.block_id) + + def append(self, new_block: Block): + self._blocks.append(new_block) + self._add_block_id(new_block.block_id) + + def __len__(self) -> int: + return len(self._blocks) + + def __getitem__(self, block_index: int) -> Block: + return self._blocks[block_index] + + def __setitem__(self, block_index: int, new_block: Block) -> None: + self._blocks[block_index] = new_block + self._update_block_id(block_index, new_block.block_id) + + def reset(self): + self._blocks = [] + self._block_ids = [] + + def list(self) -> List[Block]: + return self._blocks + + def ids(self) -> List[int]: + return self._block_ids + + +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. + Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. 
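+        # For example, with the default block_size of 1000, the per-block hit
+        # rate is folded into completed_block_cache_hit_rate as a running
+        # average once every 1000 queries, keeping the counters bounded.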
+ if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + +def get_all_blocks_recursively(last_block: Block) -> List[Block]: + """Retrieves all the blocks in a sequence starting from the last block. + + This function recursively traverses the sequence of blocks in reverse order, + starting from the given last block, and returns a list of all the blocks in + the sequence. + + Args: + last_block (Block): The last block in the sequence. + + Returns: + List[Block]: A list of all the blocks in the sequence, in the order they + appear. + """ + + def recurse(block: Block, lst: List[Block]) -> None: + if block.prev_block is not None: + recurse(block.prev_block, lst) + lst.append(block) + + all_blocks: List[Block] = [] + recurse(last_block, all_blocks) + return all_blocks diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py new file mode 100644 index 0000000..d64142e --- /dev/null +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -0,0 +1,440 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, FrozenSet, List, Optional, Tuple + +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.platforms import current_platform +from vllm.utils import Device + + +class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): + """A block allocator that can allocate blocks on both CPU and GPU memory. + + This class implements the `DeviceAwareBlockAllocator` interface and provides + functionality for allocating and managing blocks of memory on both CPU and + GPU devices. + + The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU + blocks, and allows for allocation, deallocation, forking, and swapping of + blocks across these memory pools. + """ + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + ) -> DeviceAwareBlockAllocator: + """Creates a CpuGpuBlockAllocator instance with the specified + configuration. + + This static method creates and returns a CpuGpuBlockAllocator instance + based on the provided parameters. It initializes the CPU and GPU block + allocators with the specified number of blocks, block size, and + allocator type. + + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. 
Currently supported values are "naive" and + "prefix_caching". + num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the + specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + """ + # For HPU, block id 0 is used only for padding + reserved_blocks = 1 if current_platform.is_hpu() else 0 + block_ids = list( + range(reserved_blocks, num_gpu_blocks + num_cpu_blocks)) + num_gpu_blocks -= reserved_blocks + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:] + + if allocator_type == "naive": + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + elif allocator_type == "prefix_caching": + gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + else: + raise ValueError(f"Unknown allocator type {allocator_type=}") + + return CpuGpuBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + ) + + def __init__(self, cpu_block_allocator: BlockAllocator, + gpu_block_allocator: BlockAllocator): + assert not ( + cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + ), "cpu and gpu block allocators can't have intersection of block ids" + + self._allocators = { + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator, + } + + self._swap_mapping: Dict[int, int] = {} + self._null_block: Optional[Block] = None + + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} + for _, allocator in self._allocators.items(): + for block_id in allocator.all_block_ids: + self._block_ids_to_allocator[block_id] = allocator + + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable_block(None, Device.GPU)) + return self._null_block + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new mutable block on the specified device. + + Args: + prev_block (Optional[Block]): The previous block to in the sequence. + Used for prefix hashing. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + Block: The newly allocated mutable block. + """ + return self._allocators[device].allocate_mutable_block( + prev_block, extra_hash=extra_hash) + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + """Allocates a new group of immutable blocks with the provided block + token IDs on the specified device. 
+ + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + block_token_ids (List[int]): The list of block token IDs to be + stored in the new blocks. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + List[Block]: The newly allocated list of immutable blocks + containing the provided block token IDs. + """ + return self._allocators[device].allocate_immutable_blocks( + prev_block, block_token_ids, extra_hash=extra_hash) + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new immutable block with the provided token IDs on the + specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + token_ids (List[int]): The list of token IDs to be stored in the new + block. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + Block: The newly allocated immutable block containing the provided + token IDs. + """ + return self._allocators[device].allocate_immutable_block( + prev_block, token_ids, extra_hash=extra_hash) + + def free(self, block: Block) -> None: + """Frees the memory occupied by the given block. + + Args: + block (Block): The block to be freed. + """ + # Null block should never be freed + if isinstance(block, NullBlock): + return + block_id = block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + allocator.free(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: A new list of blocks that shares the same memory as the + original sequence. + """ + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) + block_id = last_block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + return allocator.fork(last_block) + + def get_num_free_blocks(self, device: Device) -> int: + """Returns the number of free blocks available on the specified device. + + Args: + device (Device): The device for which to query the number of free + blocks. AssertionError is raised if None is passed. + + Returns: + int: The number of free blocks available on the specified device. + """ + return self._allocators[device].get_num_free_blocks() + + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + """Returns the zero-offset block id on certain device given the + absolute block id. + + Args: + device (Device): The device for which to query relative block id. + absolute_id (int): The absolute block id for the block in + whole allocator. + + Returns: + int: The zero-offset block id on certain device. 
+ """ + return self._allocators[device].get_physical_block_id(absolute_id) + + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + """Execute the swap for the given blocks from source_device + on to dest_device, save the current swap mapping and append + them to the accumulated `self._swap_mapping` for each + scheduling move. + + Args: + blocks: List of blocks to be swapped. + src_device (Device): Device to swap the 'blocks' from. + dst_device (Device): Device to swap the 'blocks' to. + + Returns: + Dict[int, int]: Swap mapping from source_device + on to dest_device. + """ + src_block_ids = [block.block_id for block in blocks] + self._allocators[src_device].swap_out(blocks) + self._allocators[dst_device].swap_in(blocks) + dst_block_ids = [block.block_id for block in blocks] + + current_swap_mapping: Dict[int, int] = {} + for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): + if src_block_id is not None and dst_block_id is not None: + self._swap_mapping[src_block_id] = dst_block_id + current_swap_mapping[src_block_id] = dst_block_id + return current_swap_mapping + + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + + Args: + blocks: List of blocks to be swapped. + device (Device): Device to swap the 'blocks' on. + + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + Non full blocks are ignored when deciding the number + of blocks to touch. + """ + return self._allocators[device].get_num_full_blocks_touched(blocks) + + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + """Clears the copy-on-write (CoW) state and returns the mapping of + source to destination block IDs. + + Returns: + List[Tuple[int, int]]: A list mapping source block IDs to + destination block IDs. + """ + # CoW only supported on GPU + device = Device.GPU + return self._allocators[device].clear_copy_on_writes() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed(block_ids) + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].get_common_computed_block_ids( + computed_seq_block_ids) + + @property + def all_block_ids(self) -> FrozenSet[int]: + return frozenset(self._block_ids_to_allocator.keys()) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + assert device in self._allocators + return self._allocators[device].get_prefix_cache_hit_rate() + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache for specified or all devices.""" + if device: + return self._allocators[device].reset_prefix_cache() + success = True + for allocator in self._allocators.values(): + success = success and allocator.reset_prefix_cache() + return success + + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. Currently not useful. + + Returns: + List[Tuple[int, int]]: A mapping of source to destination block IDs. + """ + mapping = self._swap_mapping.copy() + self._swap_mapping.clear() + return list(mapping.items()) + + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + return self._allocators[device].find_cached_blocks_prefix(block_hashes) + + +class NullBlock(Block): + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + This implementation just wraps an ordinary block and prevents it from + being modified. It also allows for testing if a block is NullBlock + via isinstance(). + """ + + def __init__(self, proxy: Block): + super().__init__() + self._proxy = proxy + + def append_token_ids(self, token_ids: List[BlockId]): + raise ValueError("null block should not be modified") + + @property + def block_id(self): + return self._proxy.block_id + + @block_id.setter + def block_id(self, value: Optional[BlockId]): + raise ValueError("null block should not be modified") + + @property + def token_ids(self) -> List[BlockId]: + return self._proxy.token_ids + + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for null block") + + @property + def num_empty_slots(self) -> BlockId: + return self._proxy.num_empty_slots + + @property + def is_full(self): + return self._proxy.is_full + + @property + def prev_block(self): + return self._proxy.prev_block + + @property + def extra_hash(self): + return None + + @property + def computed(self): + return self._proxy.computed + + @computed.setter + def computed(self, value): + self._proxy.computed = value + + @property + def last_accessed(self) -> float: + return self._proxy.last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._proxy.last_accessed = last_accessed_ts + + @property + def content_hash(self): + return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py new file mode 100644 index 0000000..3016569 --- /dev/null +++ b/vllm/core/block/interfaces.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple + +from vllm.utils import Device + +BlockId = int + + +class Block(ABC): + + @abstractmethod + def append_token_ids(self, token_ids: List[int]) -> None: + pass + + @property + @abstractmethod + def block_id(self) -> Optional[int]: + pass + + @block_id.setter + @abstractmethod + def block_id(self, value: Optional[int]) -> None: + """NOTE: Do not use this API outside Block.""" + self._block_id = value + + @property + @abstractmethod + def token_ids(self) -> List[int]: + 
pass + + @property + @abstractmethod + def num_tokens_total(self) -> int: + """The number of tokens till the current block (inclusive) + """ + pass + + @property + @abstractmethod + def num_empty_slots(self) -> int: + pass + + @property + @abstractmethod + def is_full(self) -> bool: + pass + + @property + @abstractmethod + def prev_block(self) -> Optional["Block"]: + pass + + @property + @abstractmethod + def extra_hash(self) -> Optional[int]: + return None + + @property + @abstractmethod + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + @abstractmethod + def computed(self, value) -> bool: + """Should be only used by PrefixCacingAllocator""" + raise NotImplementedError + + @property + @abstractmethod + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + @abstractmethod + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + class Factory(Protocol): + + @abstractmethod + def __call__( + self, + prev_block: Optional["Block"], + token_ids: List[int], + block_size: int, + allocator: "BlockAllocator", + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ) -> "Block": + pass + + @property + @abstractmethod + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined or not supported. + + For the content-based hash to be defined, the current block must be + full. + """ + return None + + +class BlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, prev_block: Optional[Block], + extra_hash: Optional[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int]) -> List[Block]: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: + pass + + @abstractmethod + def get_physical_block_id(self, absolute_id: int) -> int: + pass + + @abstractmethod + def swap_out(self, blocks: List[Block]) -> None: + pass + + @abstractmethod + def swap_in(self, blocks: List[Block]) -> None: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def promote_to_immutable_block(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self) -> bool: + """Reset prefix cache.""" + pass + + class NoFreeBlocksError(ValueError): + pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + ) -> List[int]: + pass + + +class DeviceAwareBlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None, + ) -> List[Block]: + pass + + @abstractmethod + def get_num_free_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def get_num_total_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + pass + + @abstractmethod + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + pass + + @abstractmethod + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache.""" + pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py new file mode 100644 index 0000000..c388366 --- /dev/null +++ b/vllm/core/block/naive_block.py @@ -0,0 +1,465 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import deque +from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union + +from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, + get_all_blocks_recursively) +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device + +Refcount = int + + +class NaiveBlockAllocator(BlockAllocator): + """A simple block allocator that manages blocks of memory without prefix + caching. + + Args: + create_block (Block.Factory): A factory function for creating new + blocks. 
This is used when a NaiveBlockAllocator is composed within + a prefix caching allocator -- the naive block allocator must + construct prefix caching blocks (but shouldn't know anything else + about them). + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids (Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + def __init__( + self, + create_block: Block.Factory, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + block_pool: Optional[BlockPool] = None, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._free_block_indices: Deque[BlockId] = deque(block_ids) + self._all_block_indices = frozenset(block_ids) + assert len(self._all_block_indices) == num_blocks + + self._refcounter = RefCounter( + all_block_indices=self._free_block_indices) + self._block_size = block_size + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + if block_pool is None: + extra_factor = 4 + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + self._block_pool = BlockPool(self._block_size, create_block, self, + num_blocks * extra_factor) + else: + # In this case, the block pool is provided by the caller, + # which means that there is most likely a need to share + # a block pool between allocators + self._block_pool = block_pool + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a new immutable block with the given token IDs, linked to + the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + token_ids (List[int]): The token IDs to be stored in the new block. + + Returns: + Block: The newly allocated immutable block. + """ + assert device is None + block = self.allocate_mutable_block(prev_block=prev_block) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> List[Block]: + assert device is None + num_blocks = len(block_token_ids) + + block_ids = [] + for i in range(num_blocks): + block_ids.append(self._allocate_block_id()) + + blocks = [] + for i in range(num_blocks): + prev_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block_token_ids[i], + block_size=self._block_size, + physical_block_id=block_ids[i]) + blocks.append(prev_block) + + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a new mutable block, linked to the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + + Returns: + Block: The newly allocated mutable block. 
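+
+        Example (illustrative sketch only; the block count and block size
+        below are arbitrary):
+
+            >>> allocator = NaiveBlockAllocator(create_block=NaiveBlock,
+            ...                                 num_blocks=16, block_size=4)
+            >>> block = allocator.allocate_mutable_block(prev_block=None)
+            >>> block.append_token_ids([1, 2, 3])
+            >>> block.is_full
+            False
+            >>> allocator.get_num_free_blocks()
+            15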
+ """ + assert device is None + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id) + return block + + def _allocate_block_id(self) -> BlockId: + if not self._free_block_indices: + raise BlockAllocator.NoFreeBlocksError() + + block_id = self._free_block_indices.popleft() + self._refcounter.incr(block_id) + return block_id + + def _free_block_id(self, block: Union[Block, BlockId]) -> None: + if isinstance(block, Block): + block_id = block.block_id + block.block_id = None + else: + block_id = block + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount == 0: + self._free_block_indices.appendleft(block_id) + + def free(self, block: Block, keep_block_object: bool = False) -> None: + # Release the physical block id + self._free_block_id(block) + + # Release the block object + if not keep_block_object: + self._block_pool.free_block(block) + + def free_block_id(self, block_id: BlockId) -> None: + self._free_block_id(block_id) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. + """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks: List[Block] = [] + prev_block = None + for block in source_blocks: + + # Increment refcount for each block. + assert block.block_id is not None + refcount = self._refcounter.incr(block.block_id) + assert refcount != 1, "can't fork free'd block" + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block.block_id) + + forked_blocks.append(forked_block) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self) -> int: + return len(self._free_block_indices) + + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return sorted(self._all_block_indices).index(absolute_id) + + @property + def refcounter(self): + return self._refcounter + + @property + def all_block_ids(self) -> FrozenSet[int]: + return self._all_block_indices + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. 
+ + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as computed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Determine blocks that can be skipped in prefill. + + Since the naive allocator does not support prefix caching, always return + an empty list. + """ + return [] + + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError("There is no promotion for naive blocks") + + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. + + Args: + blocks: List of blocks to be swapped. + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. + """ + # NOTE: for naive block, we use set to eliminate common blocks among + # seqs, also we compare the empty slots in the mutable blocks with + # lookahead slots to get the number of unique new block that are + # needed. + old_block_set = set() + for block in blocks: + if block.is_full: + old_block_set.add(block) + return len(old_block_set) + + def swap_out(self, blocks: List[Block]) -> None: + for block in blocks: + self._free_block_id(block) + + def swap_in(self, blocks: List[Block]) -> None: + for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object + if block.is_full: + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) + else: + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + tmp_block.block_id = None + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id + + def get_prefix_cache_hit_rate(self) -> float: + return -1 + + def reset_prefix_cache(self) -> bool: + """No prefix cache for naive block allocator.""" + return True + + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + # Not applicable for naive block allocator. + return [] + + +class NaiveBlock(Block): + """An implementation of the Block class that does not support prefix + caching. + + The NaiveBlock class represents a block of token IDs with a fixed size. It + provides methods for appending token IDs to the block and manages copy-on + -write operations when necessary. + + Args: + prev_block (Block): The previous block in the sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The block allocator associated with this + block. + block_id (Optional[int], optional): The physical block index + of this block. 
Defaults to None, which means no allocation has been + made. + _cow_target (Optional[Block], optional): The copy-on-write target block. + If not provided, it defaults to self. + """ + + def __init__(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + _cow_target: Optional[Block] = None, + extra_hash: Optional[int] = None): + self._token_ids: List[int] = [] + self._block_size = block_size + self._prev_block = prev_block + self._block_id = block_id + self._allocator = allocator + self._cow_target = _cow_target if _cow_target is not None else self + + self._append_token_ids_no_cow(token_ids) + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and performs a + copy-on-write if necessary. + + Args: + token_ids (Optional[List[int]]): The token IDs to be appended + to the block. + """ + self._append_token_ids_no_cow(token_ids) + + if self._block_id is not None: + self._block_id = (self._allocator.cow_block_if_not_appendable( + self._cow_target)) + + def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + if len(token_ids) == 0: + return + + assert len(token_ids) <= self.num_empty_slots + + self._token_ids.extend(token_ids) + + @property + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + def computed(self, value) -> None: + raise NotImplementedError + + @property + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + @property + def block_id(self) -> Optional[int]: + return self._block_id + + @block_id.setter + def block_id(self, value: Optional[int]) -> None: + self._block_id = value + + @property + def is_full(self) -> bool: + return self.num_empty_slots == 0 + + @property + def num_empty_slots(self) -> int: + return self._block_size - len(self.token_ids) + + @property + def token_ids(self) -> List[int]: + return self._token_ids + + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for naive block") + + @property + def block_size(self) -> int: + return self._block_size + + @property + def prev_block(self) -> Optional["Block"]: + return self._prev_block + + @property + def extra_hash(self): + return None + + @property + def content_hash(self) -> Optional[int]: + return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py new file mode 100644 index 0000000..1ca9e49 --- /dev/null +++ b/vllm/core/block/prefix_caching_block.py @@ -0,0 +1,1134 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Token blocks.""" +import sys +from bisect import bisect_left +from os.path import commonprefix +from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, + Tuple) + +from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, + get_all_blocks_recursively) +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import (BlockPool, NaiveBlock, + NaiveBlockAllocator) +from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.logger import init_logger +from vllm.sequence import Sequence + +PrefixHash = int + +# By default, we init our block access 
time as _DEFAULT_LAST_ACCESSED_TIME +# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, +# then we know this block hasn't been accessed yet. +_DEFAULT_LAST_ACCESSED_TIME = -1 + +logger = init_logger(__name__) + + +class BlockTracker: + """Used to track the status of a block inside the prefix caching allocator + """ + __slots__ = ("active", "last_accessed", "computed") + + def reset(self): + self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self.computed: bool = False + + def __init__(self): + self.active: bool = False + self.reset() + + def enable(self): + assert not self.active + self.active = True + self.reset() + + def disable(self): + assert self.active + self.active = False + self.reset() + + +class PrefixCachingBlockAllocator(BlockAllocator): + """A block allocator that implements prefix caching. + + The PrefixCachingBlockAllocator maintains a cache of blocks based on their + content hash. It reuses blocks with the same content hash to avoid redundant + memory allocation. The allocator also supports copy-on-write operations. + + Args: + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids(Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + _none_hash: int = hash('None') + + # Implements Block.Factory. + def __init__( + self, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._block_size = block_size + + # A mapping of prefix hash to block index. All blocks which have a + # prefix hash will be in this dict, even if they have refcount 0. + self._cached_blocks: Dict[PrefixHash, BlockId] = {} + + # A list of immutable block IDs that have been touched by scheduler + # and should be marked as computed after an entire batch of sequences + # are scheduled. + self._touched_blocks: Set[BlockId] = set() + + # Used to track status of each physical block id + self._block_tracker: Dict[BlockId, BlockTracker] = {} + for block_id in block_ids: + self._block_tracker[block_id] = BlockTracker() + + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + extra_factor = 4 + self._block_pool = BlockPool(self._block_size, self._create_block, + self, num_blocks * extra_factor) + + # An allocator for blocks that do not have prefix hashes. + self._hashless_allocator = NaiveBlockAllocator( + create_block=self._create_block, # type: ignore + num_blocks=num_blocks, + block_size=block_size, + block_ids=block_ids, + block_pool=self._block_pool, # Share block pool here + ) + + # Evitor used to maintain how we want to handle those computed blocks + # if we find memory pressure is high. + self.eviction_policy = eviction_policy + self.evictor: Evictor = make_evictor(self.eviction_policy) + + # We share the refcounter between allocators. 
This allows us to promote + # blocks originally allocated in the hashless allocator to immutable + # blocks. + self._refcounter = self._hashless_allocator.refcounter + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + self.metric_data = CacheMetricData() + + def _create_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ) -> Block: + # Bind block to self. + allocator = self + + return PrefixCachingBlock( + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=allocator, + computed=computed, + extra_hash=extra_hash, + ) + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates an immutable block with the given token IDs, reusing cached + blocks if possible. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + token_ids (List[int]): The token IDs to be stored in the block. + + Returns: + Block: The allocated immutable block. + """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + # First, try to create a block that points to cached data + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=token_ids, + block_size=self._block_size, + physical_block_id=None, + extra_hash=extra_hash) + assert block.content_hash is not None + + cached_block_id = self._cached_blocks.get(block.content_hash, None) + if cached_block_id is not None: + self.metric_data.query(hit=True) + block.block_id = cached_block_id + self._incr_refcount_cached_block(block) + return block + self.metric_data.query(hit=False) + self._block_pool.free_block(block) + + # No cached block => Allocate a new block + block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> List[Block]: + blocks = [] + for token_ids in block_token_ids: + prev_block = self.allocate_immutable_block(prev_block=prev_block, + token_ids=token_ids, + device=device, + extra_hash=extra_hash) + blocks.append(prev_block) + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a mutable block. If there are no free blocks, this will + evict unused cached blocks. + + Args: + prev_block (Block): The previous block in the sequence. + None is not allowed unlike it is super class. + + Returns: + Block: The allocated mutable block. 
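+
+        Example (illustrative sketch; a tiny block size is used only to keep
+        the snippet short):
+
+            >>> allocator = PrefixCachingBlockAllocator(num_blocks=16,
+            ...                                         block_size=4)
+            >>> block = allocator.allocate_mutable_block(prev_block=None)
+            >>> block.content_hash is None
+            True
+            >>> block.append_token_ids([1, 2, 3, 4])  # block becomes full
+            >>> block.content_hash is not None        # ...and gets promoted
+            True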
+ """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id, + extra_hash=extra_hash) + assert not block.computed + assert block.content_hash is None + return block + + def _incr_refcount_cached_block(self, block: Block) -> None: + # Set this block to be "computed" since it is pointing to a + # cached block id (which was already computed) + block.computed = True + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + if refcount == 1: + # In case a cached block was evicted, restore its tracking + if block_id in self.evictor: + self.evictor.remove(block_id) + + self._track_block_id(block_id, computed=True) + + def _decr_refcount_cached_block(self, block: Block) -> None: + # Ensure this is immutable/cached block + assert block.content_hash is not None + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount > 0: + block.block_id = None + return + else: + assert refcount == 0 + + # No longer used + assert block.content_hash in self._cached_blocks + + # Add the cached block to the evictor + # (This keeps the cached block around so it can be reused) + self.evictor.add(block_id, block.content_hash, block.num_tokens_total, + self._block_tracker[block_id].last_accessed) + + # Stop tracking the block + self._untrack_block_id(block_id) + + block.block_id = None + + def _decr_refcount_hashless_block(self, block: Block) -> None: + block_id = block.block_id + assert block_id is not None + + # We may have a fork case where block is shared, + # in which case, we cannot remove it from tracking + refcount = self._refcounter.get(block_id) + if refcount == 1: + self._untrack_block_id(block_id) + + # Decrement refcount of the block_id, but do not free the block object + # itself (will be handled by the caller) + self._hashless_allocator.free(block, keep_block_object=True) + + def _allocate_block_id(self) -> BlockId: + """First tries to allocate a block id from the hashless allocator, + and if there are no blocks, then tries to evict an unused cached block. + """ + hashless_block_id = self._maybe_allocate_hashless_block_id() + if hashless_block_id is not None: + return hashless_block_id + + evicted_block_id = self._maybe_allocate_evicted_block_id() + if evicted_block_id is not None: + return evicted_block_id + + # No block available in hashless allocator, nor in unused cache blocks. 
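+        # At this point every physical block is referenced by at least one
+        # sequence (a cached block whose refcount dropped to zero would be
+        # sitting in the evictor), so allocation can only succeed again once
+        # blocks are freed or swapped out.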
+ raise BlockAllocator.NoFreeBlocksError() + + def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: + try: + # Allocate mutable block and extract its block_id + block = self._hashless_allocator.allocate_mutable_block( + prev_block=None) + block_id = block.block_id + self._block_pool.free_block(block) + + self._track_block_id(block_id, computed=False) + return block_id + except BlockAllocator.NoFreeBlocksError: + return None + + def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: + if self.evictor.num_blocks == 0: + return None + + # Here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list + block_id, content_hash_to_evict = self.evictor.evict() + + # Sanity checks + assert content_hash_to_evict in self._cached_blocks + _block_id = self._cached_blocks[content_hash_to_evict] + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id + + self._cached_blocks.pop(content_hash_to_evict) + + self._refcounter.incr(block_id) + self._track_block_id(block_id, computed=False) + + return block_id + + def _free_block_id(self, block: Block) -> None: + """Decrements the refcount of the block. The block may be in two + possible states: (1) immutable/cached or (2) mutable/hashless. + In the first case, the refcount is decremented directly and the block + may be possibly added to the evictor. In other case, hashless + allocator free(..) with keep_block_object=True is called to only free + the block id (since the block object may be reused by the caller) + """ + block_id = block.block_id + assert block_id is not None, "Freeing unallocated block is undefined" + + if block.content_hash is not None: + # Immutable: This type of block is always cached, and we want to + # keep it in the evictor for future reuse + self._decr_refcount_cached_block(block) + else: + # Mutable: This type of block is not cached, so we release it + # directly to the hashless allocator + self._decr_refcount_hashless_block(block) + + assert block.block_id is None + + def free(self, block: Block, keep_block_object: bool = False) -> None: + """Release the block (look at free_block_id(..) docs) + """ + # Release the physical block index + self._free_block_id(block) + + # Release the block object to the pool + if not keep_block_object: + self._block_pool.free_block(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. 
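+
+        Example (illustrative; ``seq_blocks`` stands for the already
+        allocated, in-order chain of blocks of the sequence being forked):
+
+            >>> forked = allocator.fork(seq_blocks[-1])
+            >>> [b.block_id for b in forked] == [b.block_id for b in seq_blocks]
+            True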
+ """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks: List[Block] = [] + prev_block = None + for block in source_blocks: + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + assert refcount != 1, "can't fork free'd block_id = {}".format( + block_id) + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block_id, + extra_hash=block.extra_hash) + + forked_blocks.append(forked_block) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + assert device is None + # The number of free blocks is the number of hashless free blocks + # plus the number of blocks evictor could free from its list. + return self._hashless_allocator.get_num_free_blocks( + ) + self.evictor.num_blocks + + def get_num_total_blocks(self) -> int: + return self._hashless_allocator.get_num_total_blocks() + + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The rzero-offset block id on certain device. + """ + return sorted(self.all_block_ids).index(absolute_id) + + @property + def all_block_ids(self) -> FrozenSet[int]: + return self._hashless_allocator.all_block_ids + + def get_prefix_cache_hit_rate(self) -> float: + return self.metric_data.get_hit_rate() + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + num_used_blocks = (self.get_num_total_blocks() - + self.get_num_free_blocks()) + if num_used_blocks > 0: + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return False + + # Free all blocks in the evictor. + while (block_id := + self._maybe_allocate_evicted_block_id()) is not None: + self._hashless_allocator.free_block_id(block_id) + + # Should not have any cached blocks because all blocks are evicted. + assert not self._cached_blocks + + # Reset the evictor. + self.evictor = make_evictor(self.eviction_policy) + + # Reset the block tracker. + for block_id in self._block_tracker: + self._block_tracker[block_id] = BlockTracker() + + # Reset the metrics. + self.metric_data = CacheMetricData() + + logger.info("Successfully reset prefix cache") + return True + + def is_block_cached(self, block: Block) -> bool: + assert block.content_hash is not None + return block.content_hash in self._cached_blocks + + def promote_to_immutable_block(self, block: Block) -> BlockId: + """Once a mutable block is full, it can be promoted to an immutable + block. This means that its content can be referenced by future blocks + having the same prefix. + + Note that if we already have a cached block with the same content, we + will replace the newly-promoted block's mapping with the existing cached + block id. + + Args: + block: The mutable block to be promoted. + + Returns: + BlockId: Either the original block index, or the block index of + the previously cached block matching the same content. 
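+
+        Example (conceptual sketch of the observable effect, shown through the
+        public allocate_immutable_block API which relies on this promotion
+        path; assumes block_size == 4):
+
+            >>> a = allocator.allocate_immutable_block(None, [1, 2, 3, 4])
+            >>> b = allocator.allocate_immutable_block(None, [1, 2, 3, 4])
+            >>> a.block_id == b.block_id  # identical content shares one block
+            True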
+ """ + # Ensure block can be promoted + assert block.content_hash is not None + assert block.block_id is not None + assert self._refcounter.get(block.block_id) > 0 + + if block.content_hash not in self._cached_blocks: + # No cached content hash => Set this block as cached. + # Note that this block cannot be marked as computed yet + # because other sequences in the same batch cannot reuse + # this block. + self._cached_blocks[block.content_hash] = block.block_id + # Mark this block as touched so that it can be marked as + # computed after the entire batch of sequences are scheduled. + self._touched_blocks.add(block.block_id) + return block.block_id + + # Reuse the cached content hash + self._decr_refcount_hashless_block(block) + block.block_id = self._cached_blocks[block.content_hash] + + # Increment refcount of the cached block and (possibly) restore + # it from the evictor. + # Note that in this case, the block is marked as computed + self._incr_refcount_cached_block(block) + + return block.block_id + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + If the block is added into evictor, we need to update corresponding + info in evictor's metadata. + """ + + for block_id in block_ids: + if self._block_tracker[block_id].active: + self._block_tracker[block_id].last_accessed = now + elif block_id in self.evictor: + self.evictor.update(block_id, now) + else: + raise ValueError( + "Mark block as accessed which is not belonged to GPU") + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + # Mark all touched blocks as computed. + for block_id in self._touched_blocks: + self._block_tracker[block_id].computed = True + self._touched_blocks.clear() + + def _track_block_id(self, block_id: Optional[BlockId], + computed: bool) -> None: + assert block_id is not None + self._block_tracker[block_id].enable() + self._block_tracker[block_id].computed = computed + + def _untrack_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_tracker[block_id].disable() + + def block_is_computed(self, block_id: int) -> bool: + if self._block_tracker[block_id].active: + return self._block_tracker[block_id].computed + else: + return block_id in self.evictor + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Return the block ids that are common for a given sequence group. 
+ + Only those blocks that are immutable and already be marked + compyted would be taken consideration. + """ + + # NOTE We exclude the last block to avoid the case where the entire + # prompt is cached. This would cause erroneous behavior in model + # runner. + + # It returns a list of int although type annotation says list of string. + if len(computed_seq_block_ids) == 1: + return computed_seq_block_ids[0] + + return commonprefix([ + ids for ids in computed_seq_block_ids # type: ignore + if ids + ]) + + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. + + Args: + blocks: List of blocks to be swapped. + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. + """ + num_touched_blocks: int = 0 + for block in blocks: + # If the block has a match in the cache and the cached + # block is not referenced, then we still count it as a + # touched block + if block.is_full and (not self.is_block_cached(block) or \ + (block.content_hash is not None and \ + self._cached_blocks[block.content_hash] in \ + self.evictor)): + num_touched_blocks += 1 + return num_touched_blocks + + def swap_out(self, blocks: List[Block]) -> None: + """Execute the swap out actions. Basically just free the + given blocks. + + Args: + blocks: List of blocks to be swapped out. + """ + for block in blocks: + self._free_block_id(block) + + def swap_in(self, blocks: List[Block]) -> None: + """Execute the swap in actions. Change the block id from + old allocator to current allocator for each block to finish + the block table update. + + Args: + blocks: List of blocks to be swapped in. + """ + for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object + if block.is_full: + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, + token_ids=block.token_ids, + extra_hash=block.extra_hash) + else: + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block, extra_hash=block.extra_hash) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id + + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + """ + Given a list of block hashes, return the prefix of the block hashes that + are all cached. + + Since a block's block hash includes the hashes of all previous blocks, + and we only allocate/deallocate blocks in the entire sequence, so if a + block is cached, then all previous blocks are also cached. With this + property, we can use binary search to find the prefix of cached blocks. + + Args: + block_hashes (List[int]): The list of block hashes. + + Returns: + List[int]: The prefix of the `block_hashes` that are cached. + """ + + def _block_is_cached(block_hash: PrefixHash) -> bool: + if block_hash not in self._cached_blocks: + return False + + cached_block_id = self._cached_blocks[block_hash] + # We only consider the blocks that are marked as computed. 
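+            # (A hash may already be present in self._cached_blocks while its
+            # block is not yet marked computed, e.g. when the block was only
+            # promoted within the batch currently being scheduled.)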
+ return self.block_is_computed(cached_block_id) + + def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int: + + # python <= 3.10 don't have the key argument + if sys.version_info < (3, 10): + a = [key(e) for e in a] + return bisect_left(a, x) + else: + return bisect_left(a, x, key=key) + + # Look for the first block that's not cached, and returns the prefix + # i.e. blocks that are cached. + idx = _bisect_left(block_hashes, + True, + key=lambda x: not _block_is_cached(x)) + return block_hashes[:idx] + + +class PrefixCachingBlock(Block): + """A block implementation that supports prefix caching. + + The PrefixCachingBlock class represents a block of token IDs with prefix + caching capabilities. It wraps a NaiveBlock internally and provides + additional functionality for content hashing and promoting immutable blocks + with the prefix caching allocator. + + Args: + prev_block (Optional[PrefixCachingBlock]): The previous block in the + sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The prefix + caching block allocator associated with this block. + block_id (Optional[int], optional): The physical block index + of this block. Defaults to None. + extra_hash (Optional[int]): The hash value of additional factors + such as adapters that influence the block, apart from the token_ids. + """ + + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + _none_hash: int = hash('None') + + def __init__( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ): + assert isinstance(allocator, PrefixCachingBlockAllocator), ( + "Currently this class is only tested with " + "PrefixCachingBlockAllocator. 
Got instead allocator = {}".format( + allocator)) + assert_prefix_caching_block_or_none(prev_block) + + self._prev_block = prev_block + self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: int = 0 + self._allocator = allocator + self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self._computed = computed + self._extra_hash = extra_hash + + # On the first time, we create the block object, and next we only + # reinitialize it + if hasattr(self, "_block"): + self._block.__init__( # type: ignore[has-type] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + else: + self._block = NaiveBlock(prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + + self._update_num_tokens_total() + + def _update_num_tokens_total(self): + """Incrementally computes the number of tokens that there is + till the current block (included) + """ + res = 0 + + # Add all previous blocks + if self._prev_block is not None: + res += self._prev_block.num_tokens_total + + # Add current block + res += len(self.token_ids) + + self._cached_num_tokens_total = res + + @property + def computed(self) -> bool: + return self._computed + + @computed.setter + def computed(self, value) -> None: + self._computed = value + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._last_accessed = last_accessed_ts + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and registers the block as + immutable if the block becomes full. + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + # Ensure this is mutable block (not promoted) + assert self.content_hash is None + assert not self.computed + + if len(token_ids) == 0: + return + + # Ensure there are input tokens + assert token_ids, "Got token_ids = {}".format(token_ids) + + # Naive block handles CoW. + self._block.append_token_ids(token_ids) + self._update_num_tokens_total() + + # If the content hash is present, then the block can be made immutable. + # Register ourselves with the allocator, potentially replacing the + # physical block index. + if self.content_hash is not None: + self.block_id = self._allocator.promote_to_immutable_block(self) + + @property + def block_id(self) -> Optional[int]: + return self._block.block_id + + @block_id.setter + def block_id(self, value) -> None: + self._block.block_id = value + + @property + def is_full(self) -> bool: + return self._block.is_full + + @property + def num_empty_slots(self) -> int: + return self._block.num_empty_slots + + @property + def num_tokens_total(self) -> int: + return self._cached_num_tokens_total + + @property + def block_size(self) -> int: + return self._block.block_size + + @property + def token_ids(self) -> List[int]: + return self._block.token_ids + + @property + def prev_block(self) -> Optional[Block]: + return self._prev_block + + @property + def extra_hash(self) -> Optional[int]: + return self._extra_hash + + @property + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined. + + For the content-based hash to be defined, the current block must be + full. + """ + # If the hash is already computed, return it. 
+ if self._cached_content_hash is not None: + return self._cached_content_hash + + # We cannot compute a hash for the current block because it is not full. + if not self.is_full: + return None + + is_first_block = self._prev_block is None + prev_block_hash = ( + self._none_hash if is_first_block else + self._prev_block.content_hash # type: ignore + ) + + # Previous block exists but does not yet have a hash. + # Return no hash in this case. + if prev_block_hash == self._none_hash and not is_first_block: + return None + + self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block, + prev_block_hash, + cur_block_token_ids=self.token_ids, + extra_hash=self._extra_hash) + return self._cached_content_hash + + @classmethod + def hash_block_tokens(cls, + is_first_block: bool, + prev_block_hash: Optional[int], + cur_block_token_ids: List[int], + extra_hash: Optional[int] = None) -> int: + """Computes a hash value corresponding to the contents of a block and + the contents of the preceding block(s). The hash value is used for + prefix caching. + + Parameters: + - is_first_block (bool): A flag indicating if the block is the first in + the sequence. + - prev_block_hash (Optional[int]): The hash of the previous block. None + if this is the first block. + - cur_block_token_ids (List[int]): A list of token ids in the current + block. The current block is assumed to be full. + - extra_hash (Optional[int]): The hash value of additional factors + such as adapters that influence the block, apart from the token_ids. + + Returns: + - int: The computed hash value for the block. + """ + if is_first_block and prev_block_hash is None: + prev_block_hash = cls._none_hash + return hash((is_first_block, prev_block_hash, *cur_block_token_ids, + extra_hash)) + + +class ComputedBlocksTracker: + """ + Tracks the computed blocks for each sequence. + + Internally, it maintains a map from sequence id to the list of block hashes + for the sequence. We cache the hashes of the full blocks for each sequence, + and make sure the hash is calculated in the same way as the allocator. + When a sequence is being decoded, we also update the sequence's hash + accordingly and incrementally. + + From the sequence hash, with prefix caching enabled, we could also calculate + the number of cached tokens for the sequence by looking up the number of + cached block hashes in the allocator. + """ + + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + _none_hash: int = hash('None') + + def __init__( + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, + ): + self._allocator = allocator + self._block_size = block_size + self._enable_caching = enable_caching + + # A map from seq_id to the list of block hashes for the + # sequence. This is so that we don't have to recompute the block hashes + # for the sequence when we need to check if the sequence is cached. + # Note a block that's not full will not have its hash calculated and + # recorded. + self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} + + # A map from seq_id to the number of tokens that are cached for the + # sequence. 
+ # We need this so that a sequence in continuous prefill doesn't + # accidentally see its cached token count change. See comments in + # `get_num_cached_tokens` for more details. + self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + + def _update_seq_hashes(self, seq: Sequence) -> None: + """Incrementally update the sequence's block hashes and record them.""" + assert self._enable_caching + + block_hashes_recorded = self._seq_id_to_blocks_hashes.get( + seq.seq_id, []) + cur_num_blocks_recorded = len(block_hashes_recorded) + token_ids = seq.get_token_ids() + assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( + f"The sequence has {len(token_ids)} tokens, but" + f" already recorded {cur_num_blocks_recorded} blocks. " + "This should not happen since we assume blocks are " + "only appended other than recomputation. When the sequence is " + "recomputed, we should have removed the info of the old blocks.") + # Update the computed block hashes for the sequence. Since only full + # blocks are considered as "computed", we take floor here. + num_computed_blocks = len(token_ids) // self._block_size + + # We need to know the hash of the previous block to compute the hash of + # the current block so that blocks could be uniquely identified across + # sequences of prefixes. + prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else + block_hashes_recorded[-1]) + # Only update the computed block hashes for the new blocks + for i in range(cur_num_blocks_recorded, num_computed_blocks): + assert len(token_ids) >= (i + 1) * self._block_size + block_token_ids = token_ids[i * self._block_size:(i + 1) * + self._block_size] + + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + + # This has to be kept in sync with the allocator's hash + # calculation. + block_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block=prev_block_hash == self._none_hash, + prev_block_hash=prev_block_hash, + cur_block_token_ids=block_token_ids, + extra_hash=extra_hash, + ) + block_hashes_recorded.append(block_hash) + prev_block_hash = block_hash + + self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded + + def get_num_cached_tokens(self, seq: Sequence) -> int: + if not self._enable_caching: + return 0 + + # We always try to update the sequence hashes on the fly. + # This is to ensure that we don't miss any cached tokens for the + # sequence during decode. + # This routine should only update hash for any new blocks too. + self._update_seq_hashes(seq) + + num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( + seq.seq_id, None) + + # TODO(rickyx): This hack could be removed once we mark blocks as + # computed correctly with chunked prefills. + if num_computed_tokens_prev is not None and seq.is_prefill(): + # For a sequence that is still in prefill, we don't + # recompute the number of cached tokens. + # This also handles correctly chunked prefill since currently + # we mark blocks as computed even if the sequence is still partially + # prefilled. So a continuously prefilled sequence should not + # see its cached token count change while running. + return num_computed_tokens_prev + + block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] + + # This is O(logN), where N is the number of blocks. 
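+        # Illustrative example (not part of the original logic): with
+        # block_size = 16 and block_hashes = [h0, h1, h2], where only the
+        # blocks for h0 and h1 are cached and computed,
+        # find_cached_blocks_prefix returns [h0, h1], so the sequence is
+        # credited with 2 * 16 = 32 cached tokens below.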
+ num_cached_blocks = len( + self._allocator.find_cached_blocks_prefix(block_hashes)) + num_cached_tokens = num_cached_blocks * self._block_size + self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens + return num_cached_tokens + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking the sequence.""" + if not self._enable_caching: + return + assert seq_id in self._seq_id_to_blocks_hashes + del self._seq_id_to_blocks_hashes[seq_id] + + assert seq_id in self._seq_id_to_num_tokens_computed + del self._seq_id_to_num_tokens_computed[seq_id] + + +class LastAccessBlocksTracker: + """Manages the last access time of the tracked sequences, in order to allow + an efficient update of allocator's block last access times + """ + + def __init__(self, allocator): + self._allocator = allocator + self._seq_last_access: Dict[int, Optional[float]] = {} + + def add_seq(self, seq_id: int) -> None: + """Start tracking seq_id + """ + assert seq_id not in self._seq_last_access + self._seq_last_access[seq_id] = None + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking seq_id + """ + assert seq_id in self._seq_last_access + del self._seq_last_access[seq_id] + + def update_last_access(self, seq_id: int, time: float) -> None: + assert seq_id in self._seq_last_access + self._seq_last_access[seq_id] = time + + def update_seq_blocks_last_access(self, seq_id: int, + block_ids: List[int]) -> None: + assert seq_id in self._seq_last_access + + ts = self._seq_last_access[seq_id] + + if ts is None: + # No last access was recorded, no need to update. + return + + self._allocator.mark_blocks_as_accessed(block_ids, ts) + + +def assert_prefix_caching_block_or_none(block: Optional[Block]): + if block is None: + return + assert isinstance(block, + PrefixCachingBlock), "Got block = {}".format(block) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000..910afdd --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Block manager utils.""" +from vllm.sequence import SequenceGroup +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) + + +def check_no_caching_or_swa_for_blockmgr_encdec( + block_mgr, seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. 
+ + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if block_mgr.max_block_sliding_window is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py new file mode 100644 index 0000000..c6bf6d1 --- /dev/null +++ b/vllm/core/block_manager.py @@ -0,0 +1,520 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A block manager that manages token blocks.""" +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Tuple + +from vllm.core.block.block_table import BlockTable +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + LastAccessBlocksTracker) +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +SeqId = int +EncoderSeqId = str + + +class SelfAttnBlockSpaceManager(BlockSpaceManager): + """BlockSpaceManager which manages the allocation of KV cache. + + It owns responsibility for allocation, swapping, allocating memory for + autoregressively-generated tokens, and other advanced features such as + prefix caching, forking/copy-on-write, and sliding-window memory allocation. + + This class implements the design described in + https://github.com/vllm-project/vllm/pull/3492. + + Lookahead slots + The block manager has the notion of a "lookahead slot". These are slots + in the KV cache that are allocated for a sequence. Unlike the other + allocated slots, the content of these slots is undefined -- the worker + may use the memory allocations in any way. + + In practice, a worker could use these lookahead slots to run multiple + forward passes for a single scheduler invocation. Each successive + forward pass would write KV activations to the corresponding lookahead + slot. This allows low inter-token latency use-cases, where the overhead + of continuous batching scheduling is amortized over >1 generated tokens. + + Speculative decoding uses lookahead slots to store KV activations of + proposal tokens. + + See https://github.com/vllm-project/vllm/pull/3250 for more information + on lookahead scheduling. + + Args: + block_size (int): The size of each memory block. + num_gpu_blocks (int): The number of memory blocks allocated on GPU. + num_cpu_blocks (int): The number of memory blocks allocated on CPU. + watermark (float, optional): The threshold used for memory swapping. + Defaults to 0.01. + sliding_window (Optional[int], optional): The size of the sliding + window. Defaults to None. + enable_caching (bool, optional): Flag indicating whether caching is + enabled. Defaults to False. 
+ """ + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. + self.max_block_sliding_window = num_blocks + 1 + + self.watermark = watermark + assert watermark >= 0.0 + + self.enable_caching = enable_caching + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + self.block_allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching" if enable_caching else "naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + self.block_tables: Dict[SeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} + + self._computed_blocks_tracker = ComputedBlocksTracker( + self.block_allocator, self.block_size, self.enable_caching) + self._last_access_blocks_tracker = LastAccessBlocksTracker( + self.block_allocator) + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = BlockTable.get_num_required_blocks( + seq.get_token_ids(), + block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, + ) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + num_required_blocks += BlockTable.get_num_required_blocks( + encoder_seq.get_token_ids(), + block_size=self.block_size, + ) + + if self.max_block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.max_block_sliding_window) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + device=Device.GPU) + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks + < self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + if seq.get_token_ids(): + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + + # Add blocks to the block table only if the sequence is non empty. 
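+            # Editorial note: extra_hash folds factors such as LoRA adapters
+            # into the block hashes via PrefixCachingBlock.hash_block_tokens,
+            # so sequences with identical tokens but different adapters do not
+            # share prefix-cache entries.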
+ block_table.allocate(token_ids=seq.get_token_ids(), + extra_hash=extra_hash) + + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + + # Allocate self-attention block tables for decoder sequences + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert not (set(seq.seq_id for seq in waiting_seqs) + & self.block_tables.keys()), "block table already exists" + + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = waiting_seqs[0] + block_table: BlockTable = self._allocate_sequence(seq) + self.block_tables[seq.seq_id] = block_table + + # Track seq + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Assign the block table for each sequence. + for seq in waiting_seqs[1:]: + self.block_tables[seq.seq_id] = block_table.fork() + + # Track seq + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Allocate cross-attention block table for encoder sequence + # + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + request_id = seq_group.request_id + + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + block_table = self._allocate_sequence(encoder_seq) + self.cross_block_tables[request_id] = block_table + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + """Determine if there is enough space in the GPU KV cache to continue + generation of the specified sequence group. + + We use a worst-case heuristic: assume each touched block will require a + new allocation (either via CoW or new block). We can append slots if the + number of touched blocks is less than the number of free blocks. + + "Lookahead slots" are slots that are allocated in addition to the slots + for known tokens. The contents of the lookahead slots are not defined. + This is used by speculative decoding when speculating future tokens. + """ + + num_touched_blocks = 0 + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + block_table = self.block_tables[seq.seq_id] + + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + Device.GPU) + return num_touched_blocks <= num_free_gpu_blocks + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + + block_table = self.block_tables[seq.seq_id] + + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), + extra_hash=seq.extra_hash(), + ) + # Return any new copy-on-writes. + new_cows = self.block_allocator.clear_copy_on_writes() + return new_cows + + def free(self, seq: Sequence) -> None: + seq_id = seq.seq_id + + if seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. 
+ return + + # Update seq block ids with the latest access time + self._last_access_blocks_tracker.update_seq_blocks_last_access( + seq_id, self.block_tables[seq.seq_id].physical_block_ids) + + # Untrack seq + self._last_access_blocks_tracker.remove_seq(seq_id) + self._computed_blocks_tracker.remove_seq(seq_id) + + # Free table/blocks + self.block_tables[seq_id].free() + del self.block_tables[seq_id] + + def free_cross(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.cross_block_tables: + # Already freed or hasn't been scheduled yet. + return + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] + + def get_block_table(self, seq: Sequence) -> List[int]: + block_ids = self.block_tables[seq.seq_id].physical_block_ids + return block_ids # type: ignore + + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.cross_block_tables + block_ids = self.cross_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids # type: ignore + + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + if self.enable_caching: + # Record the latest access time for the sequence. The actual update + # of the block ids is deferred to the sequence free(..) call, since + # only during freeing of block ids, the blocks are actually added to + # the evictor (which is when the most updated time is required) + # (This avoids expensive calls to mark_blocks_as_accessed(..)) + self._last_access_blocks_tracker.update_last_access( + seq.seq_id, now) + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + # If prefix caching is enabled, mark immutable blocks as computed + # right after they have been scheduled (for prefill). This assumes + # the scheduler is synchronous so blocks are actually computed when + # scheduling the next batch. + self.block_allocator.mark_blocks_as_computed([]) + + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + """Determine which blocks for which we skip prefill. + + With prefix caching we can skip prefill for previously-generated blocks. + Currently, the attention implementation only supports skipping cached + blocks if they are a contiguous prefix of cached blocks. + + This method determines which blocks can be safely skipped for all + sequences in the sequence group. + """ + computed_seq_block_ids = [] + for seq in seqs: + all_blocks = self.block_tables[seq.seq_id].physical_block_ids + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_cached_tokens(seq)) + assert num_cached_tokens % self.block_size == 0 + num_cached_blocks = num_cached_tokens // self.block_size + computed_block_ids = all_blocks[:num_cached_blocks] + computed_seq_block_ids.append(computed_block_ids) + + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. + return self.block_allocator.get_common_computed_block_ids( + computed_seq_block_ids) # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. 
+ return + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.fork() + + # Track child seq + self._last_access_blocks_tracker.add_seq(child_seq.seq_id) + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + """Returns the AllocStatus for the given sequence_group + with num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for the given sequence group. + """ + return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, + num_lookahead_slots) + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from CPU to GPU) generated by + swapping in the given seq_group with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from CPU + to GPU. + """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.CPU, + dst_device=Device.GPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id): + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id) + for cpu_block_id, gpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + """Returns whether we can swap out the given sequence_group + with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap out. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + bool: Whether it's possible to swap out current sequence group. + """ + alloc_status = self._can_swap(seq_group, Device.CPU, + SequenceStatus.RUNNING) + return alloc_status == AllocStatus.OK + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from GPU to CPU) generated by + swapping out the given sequence_group with num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap out. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from + GPU to CPU. 
+ """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.GPU, + dst_device=Device.CPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id): + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id) + for gpu_block_id, cpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def get_num_free_gpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.GPU) + + def get_num_free_cpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.CPU) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.block_allocator.reset_prefix_cache(device) + + def _can_swap(self, + seq_group: SequenceGroup, + device: Device, + status: SequenceStatus, + num_lookahead_slots: int = 0) -> AllocStatus: + """Returns the AllocStatus for swapping in/out the given sequence_group + on to the 'device'. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in/out. + device (Device): device to swap the 'seq_group' on. + status (SequenceStatus): The status of sequence which is needed + for action. RUNNING for swap out and SWAPPED for swap in + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for swapping in/out the given + sequence_group on to the 'device'. + """ + # First determine the number of blocks that will be touched by this + # swap. Then verify if there are available blocks in the device + # to perform the swap. + num_blocks_touched = 0 + blocks: List[Block] = [] + for seq in seq_group.get_seqs(status=status): + block_table = self.block_tables[seq.seq_id] + if block_table.blocks is not None: + # Compute the number blocks to touch for the tokens to be + # appended. This does NOT include the full blocks that need + # to be touched for the swap. + num_blocks_touched += \ + block_table.get_num_blocks_touched_by_append_slots( + block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots) + blocks.extend(block_table.blocks) + # Compute the number of full blocks to touch and add it to the + # existing count of blocks to touch. + num_blocks_touched += self.block_allocator.get_num_full_blocks_touched( + blocks, device=device) + + watermark_blocks = 0 + if device == Device.GPU: + watermark_blocks = self.watermark_blocks + + if self.block_allocator.get_num_total_blocks( + device) < num_blocks_touched: + return AllocStatus.NEVER + elif self.block_allocator.get_num_free_blocks( + device) - num_blocks_touched >= watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def get_num_cached_tokens(self, seq: Sequence) -> int: + """Get the number of tokens in blocks that are already computed and + cached in the block manager for the sequence. 
+        """
+        return self._computed_blocks_tracker.get_num_cached_tokens(seq)
diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py
new file mode 100644
index 0000000..0e363ed
--- /dev/null
+++ b/vllm/core/evictor.py
@@ -0,0 +1,156 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+import heapq
+from abc import ABC, abstractmethod
+from typing import Dict, List, Tuple
+
+
+class EvictionPolicy(enum.Enum):
+    """Enum for eviction policy used by make_evictor to instantiate the correct
+    Evictor subclass.
+    """
+    LRU = enum.auto()
+
+
+class Evictor(ABC):
+    """The Evictor subclasses should be used by the BlockAllocator class to
+    handle eviction of freed Blocks.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def __contains__(self, block_id: int) -> bool:
+        pass
+
+    @abstractmethod
+    def evict(self) -> Tuple[int, int]:
+        """Runs the eviction algorithm and returns the evicted block's
+        physical block id along with its content hash.
+        """
+        pass
+
+    @abstractmethod
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: float):
+        """Adds block to the evictor, making it a candidate for eviction"""
+        pass
+
+    @abstractmethod
+    def update(self, block_id: int, last_accessed: float):
+        """Update corresponding block's access time in metadata"""
+        pass
+
+    @abstractmethod
+    def remove(self, block_id: int):
+        """Remove a given block id from the cache."""
+        pass
+
+    @property
+    @abstractmethod
+    def num_blocks(self) -> int:
+        pass
+
+
+class BlockMetaData:
+    """Data structure for storing the key data describing a cached block, so
+    that the evictor can use it to decide which block to choose for eviction.
+
+    Here we use the physical block id as the dict key, as there may be several
+    blocks with the same content hash, but their physical ids are unique.
+    """
+
+    def __init__(self, content_hash: int, num_hashed_tokens: int,
+                 last_accessed: float):
+        self.content_hash = content_hash
+        self.num_hashed_tokens = num_hashed_tokens
+        self.last_accessed = last_accessed
+
+
+class LRUEvictor(Evictor):
+    """Evicts in a least-recently-used order using the last_accessed timestamp
+    that's recorded in the Block. If there are multiple blocks with
+    the same last_accessed time, then the one with the largest num_hashed_tokens
+    will be evicted. If two blocks each have the lowest last_accessed time and
+    highest num_hashed_tokens value, then one will be chosen arbitrarily.
+    """
+
+    # CLEANUP_THRESHOLD determines the maximum allowable size of the priority
+    # queue relative to the free table size. When this threshold is exceeded,
+    # a cleanup operation is triggered to reduce memory usage.
+    CLEANUP_THRESHOLD = 50
+
+    def __init__(self):
+        self.free_table: Dict[int, BlockMetaData] = {}
+        self.priority_queue = []
+
+    def __contains__(self, block_id: int) -> bool:
+        return block_id in self.free_table
+
+    def evict(self) -> Tuple[int, int]:
+        if len(self.free_table) == 0:
+            raise ValueError("No usable cache memory left")
+
+        while self.priority_queue:
+            # We do not remove outdated entries from the priority queue at the
+            # time of updating the last_accessed timestamp. Instead, outdated
+            # entries are filtered out here during eviction. Outdated entries
+            # are either no longer in the free table, or carry an older last
+            # accessed time.
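+            # Illustrative example (not part of the original logic): if a
+            # block is added to the evictor, removed (e.g. reused from the
+            # cache), and later added again, the queue may still hold the
+            # first, stale entry for it. That stale entry is popped here and
+            # skipped, because the block is either gone from free_table or its
+            # recorded last_accessed no longer matches.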
+            last_accessed, _, block_id, content_hash = heapq.heappop(
+                self.priority_queue)
+            if (block_id in self.free_table and
+                    self.free_table[block_id].last_accessed == last_accessed):
+                self.free_table.pop(block_id)
+                return block_id, content_hash
+
+        raise ValueError("No usable cache memory left")
+
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: float):
+        self.free_table[block_id] = BlockMetaData(content_hash,
+                                                  num_hashed_tokens,
+                                                  last_accessed)
+        heapq.heappush(
+            self.priority_queue,
+            (last_accessed, -num_hashed_tokens, block_id, content_hash))
+        self._cleanup_if_necessary()
+
+    def update(self, block_id: int, last_accessed: float):
+        self.free_table[block_id].last_accessed = last_accessed
+
+    def _cleanup_if_necessary(self):
+        if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len(
+                self.free_table):
+            self._cleanup()
+
+    def _cleanup(self):
+        new_priority_queue: List[Tuple[float, int, int, int]] = []
+
+        for block_id, block in self.free_table.items():
+            new_priority_queue.append(
+                (block.last_accessed, -block.num_hashed_tokens, block_id,
+                 block.content_hash))
+        heapq.heapify(new_priority_queue)
+
+        self.priority_queue = new_priority_queue
+
+    def remove(self, block_id: int):
+        if block_id not in self.free_table:
+            raise ValueError(
+                "Attempting to remove block that's not in the evictor")
+        self.free_table.pop(block_id)
+
+    @property
+    def num_blocks(self) -> int:
+        return len(self.free_table)
+
+
+def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
+    if eviction_policy == EvictionPolicy.LRU:
+        return LRUEvictor()
+    else:
+        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
new file mode 100644
index 0000000..4c1182d
--- /dev/null
+++ b/vllm/core/interfaces.py
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+from abc import ABC, abstractmethod
+from typing import List, Optional
+from typing import Sequence as GenericSequence
+from typing import Tuple
+
+from vllm.sequence import Sequence, SequenceGroup
+from vllm.utils import Device
+
+
+class AllocStatus(enum.Enum):
+    """Result for BlockSpaceManager.can_allocate
+
+    1. Ok: seq_group can be allocated now.
+    2. Later: seq_group cannot be allocated now, but may be later, since
+      the allocator's total capacity is larger than what the seq_group
+      requires.
+    3. Never: seq_group can never be allocated, because it is too large
+      to be allocated in GPU memory.
+ """ + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager(ABC): + + @staticmethod + def get_block_space_manager_class(version: str): + version = version.lower() + + if version == "selfattn": + from vllm.core.block_manager import SelfAttnBlockSpaceManager + return SelfAttnBlockSpaceManager + + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager + + raise ValueError(f"Unknown version {version=}") + + @abstractmethod + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + pass + + @abstractmethod + def allocate(self, seq_group: SequenceGroup) -> None: + pass + + @abstractmethod + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + pass + + @abstractmethod + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + @abstractmethod + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + pass + + @abstractmethod + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def free(self, seq: Sequence) -> None: + pass + + @abstractmethod + def get_block_table(self, seq: Sequence) -> List[int]: + pass + + @abstractmethod + def get_num_free_gpu_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_cpu_blocks(self) -> int: + pass + + @abstractmethod + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + pass + + @abstractmethod + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache for specified or all devices.""" + pass + + @abstractmethod + def get_num_cached_tokens(self, seq: Sequence) -> int: + pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py new file mode 100644 index 0000000..0f5d8ca --- /dev/null +++ b/vllm/core/placeholder_block_space_manager.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: pooling models or attention-free models like Mamba. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. 
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return [] + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: List[Sequence]) -> List[int]: + return [] + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return True + + def get_num_cached_tokens(self, seq: Sequence) -> int: + return 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py new file mode 100644 index 0000000..cf85a21 --- /dev/null +++ b/vllm/core/scheduler.py @@ -0,0 +1,2061 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +import os +import random +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Callable, Deque, Dict, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupMetadataDelta, SequenceStage, + SequenceStatus) +from vllm.utils import Device, PyObjectCache + +logger = init_logger(__name__) + +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. 
+ """ + + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +@dataclass +class SchedulingBudget: + """The available slots for scheduling. + + TODO(sang): Right now, the budget is request_id-aware meaning it can ignore + budget update from the same request_id. It is because in normal scheduling + path, we update RUNNING num_seqs ahead of time, meaning it could be + updated more than once when scheduling RUNNING requests. Since this won't + happen if we only have chunked prefill scheduling, we can remove this + feature from the API when chunked prefill is enabled by default. + """ + + token_budget: int + max_num_seqs: int + _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + # Number of cached tokens in the batch. + _num_cached_tokens: int = 0 + # Number of actual non-cached tokens in the batch. + _num_batched_tokens: int = 0 + _num_curr_seqs: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): + # We allow num_new_tokens to be 0 when the entire sequence has + # been cached. + assert num_new_tokens >= 0 + assert num_new_seqs != 0 + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) + + def remaining_token_budget(self): + return self.token_budget - self.num_batched_tokens + + def add_num_batched_tokens(self, + req_id: str, + num_batched_tokens: int, + num_cached_tokens: int = 0): + if req_id in self._request_ids_num_batched_tokens: + return + assert num_cached_tokens >= 0 + assert num_batched_tokens >= 0 + + self._request_ids_num_batched_tokens.add(req_id) + self._num_batched_tokens += num_batched_tokens + self._num_cached_tokens += num_cached_tokens + + def subtract_num_batched_tokens(self, req_id: str, + num_batched_tokens: int): + if req_id in self._request_ids_num_batched_tokens: + self._request_ids_num_batched_tokens.remove(req_id) + self._num_batched_tokens -= num_batched_tokens + + def add_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + return + + self._request_ids_num_curr_seqs.add(req_id) + self._num_curr_seqs += num_curr_seqs + + def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + self._request_ids_num_curr_seqs.remove(req_id) + self._num_curr_seqs -= num_curr_seqs + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_seqs(self): + return self._num_curr_seqs + + @property + def num_cached_tokens(self): + return self._num_cached_tokens + + +@dataclass +class ScheduledSequenceGroup: + # A sequence group that's scheduled. + seq_group: SequenceGroup + # The total chunk size (number of tokens) to process for next iteration. + # 1 for decoding. Same as prompt tokens for prefill, but if prefill is + # chunked, it can be smaller than that. + token_chunk_size: int + + +@dataclass +class SchedulerOutputs: + """The scheduling decision made from a scheduler.""" + + # Scheduled sequence groups. + scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] + # Number of prefill groups scheduled. + num_prefill_groups: int + # Total number of batched tokens. + num_batched_tokens: int + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to copy. Source to dest block. 
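+    # Editorial note: these copy-on-write pairs are produced when forked
+    # sequences that share blocks start to diverge; see
+    # SelfAttnBlockSpaceManager.append_slots, which collects them via the
+    # allocator's clear_copy_on_writes().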
+ blocks_to_copy: List[Tuple[int, int]] + # Sequence groups that are going to be ignored. + ignored_seq_groups: List[SequenceGroup] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int + preempted: int + + def __post_init__(self): + # Swap in and swap out should never happen at the same time. + assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) + + self.num_loras: int = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. + return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) + + def _sort_by_lora_ids(self): + assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups) + + def key_fn(group: ScheduledSequenceGroup): + key = (group.seq_group.lora_int_id, group.seq_group.request_id) + if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups): + # Sort sequence groups so that all prefills come before all + # decodes as required by chunked prefill. + return (not group.seq_group.is_prefill(), *key) + return key + + self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, + key=key_fn) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } + + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + + +@dataclass +class SchedulerRunningOutputs: + """The requests that are scheduled from a running queue. + + Could contain prefill (prefill that's chunked) or decodes. If there's not + enough memory, it can be preempted (for recompute) or swapped out. + """ + + # Selected sequences that are running and in a decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are running and in a prefill phase. + # I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The preempted sequences. + preempted: List[SequenceGroup] + # Sequences that are swapped out. + swapped_out: List[SequenceGroup] + # The blocks to swap out. + blocks_to_swap_out: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + + # Optimization for fast-access to seq_group lists + decode_seq_groups_list: List[SequenceGroup] + prefill_seq_groups_list: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerRunningOutputs": + return SchedulerRunningOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + decode_seq_groups_list=[], + prefill_seq_groups_list=[], + ) + + +@dataclass +class SchedulerSwappedInOutputs: + """The requests that are scheduled from a swap queue. + + Could contain prefill (prefill that's chunked) or decodes. + """ + + # Selected sequences that are going to be swapped in and is in a + # decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are going to be swapped in and in a prefill + # phase. 
I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The blocks to swap in. + blocks_to_swap_in: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerSwappedInOutputs": + return SchedulerSwappedInOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + blocks_to_swap_in=[], + blocks_to_copy=[], + num_lookahead_slots=0, + infeasible_seq_groups=[], + ) + + +@dataclass +class SchedulerPrefillOutputs: + """The requests that are scheduled from a waiting queue. + + Could contain a fresh prefill requests or preempted requests that need + to be recomputed from scratch. + """ + + # Selected sequences for prefill. + seq_groups: List[ScheduledSequenceGroup] + # Ignored sequence groups. + ignored_seq_groups: List[SequenceGroup] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerPrefillOutputs": + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=0, + ) + + +def seq_group_metadata_builder(): + return SequenceGroupMetadata(request_id="", + is_prompt=False, + seq_data={}, + sampling_params=None, + block_tables={}) + + +def scheduler_running_outputs_builder(): + return SchedulerRunningOutputs(decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + prefill_seq_groups_list=[], + decode_seq_groups_list=[]) + + +def scheduled_seq_group_builder(): + return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), + token_chunk_size=0) + # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + + +@dataclass +class PartialPrefillMetadata: + """Holds information about the partial prefills that are currently running + during a single iteration of the Scheduler. + When chunked prefill is enabled, we allow a certain number of seqs to be + partially prefilled during each iteration. Having multiple partial prefills + in flight allows us to minimize TTFT and avoid decode starvation in cases + where a single sequence group with a very large prompt blocks the queue for + too many iterations. + The number of long prefill requests is limited so that smaller + requests may jump the queue in front of them and get to the decode + phase faster. 
+ """ + + # A minimum bound on the total number of prefills to be scheduled during + # this iteration + schedulable_prefills: int + + # The number of long prefill requests currently running + long_prefills: int + + scheduler_config: SchedulerConfig + + def can_schedule(self, seq_group: SequenceGroup) -> bool: + """When concurrent partial prefills are enabled, + we limit the number of long requests and only accept + shorter requests from the queue while running them + concurrently""" + return not (seq_group.first_seq.get_num_new_tokens() + > self.scheduler_config.long_prefill_token_threshold + and self.long_prefills + >= self.scheduler_config.max_long_partial_prefills + and self.scheduler_config.max_num_partial_prefills > 1) + + def maybe_increment_partial_prefills(self, + seq_group: SequenceGroup) -> None: + # When a new prefill is scheduled, we need to know if it is a + # long request + if (seq_group.first_seq.get_num_new_tokens() + > self.scheduler_config.long_prefill_token_threshold): + self.long_prefills += 1 + + @classmethod + def from_queues( + cls, + running: Deque[SequenceGroup], + waiting: Deque[SequenceGroup], + scheduler_config: SchedulerConfig, + ) -> "PartialPrefillMetadata": + """Create a PartialPrefillMetadata object from the current state of + the scheduler's queues. + This accounts for the currently running prefill requests, and peeks into + the waiting queue to see if there are more prefills to potentially be + scheduled during this iteration.""" + prefills = 0 + long_prefills = 0 + + waiting_long_prefills = 0 + + for sg in running: + if sg.first_seq.data.stage == SequenceStage.PREFILL: + prefills += 1 + if (sg.first_seq.get_num_new_tokens() + > scheduler_config.long_prefill_token_threshold): + long_prefills += 1 + + for sg in waiting: + # Don't bother looping through the rest of the queue if we know + # there are already at + # least max_partial_prefills requests to fill + if prefills >= scheduler_config.max_num_partial_prefills: + break + + # Don't count long requests from the waiting queue if we aren't + # going to schedule them anyway + if (sg.first_seq.get_num_new_tokens() + > scheduler_config.long_prefill_token_threshold): + if (long_prefills + waiting_long_prefills + >= scheduler_config.max_long_partial_prefills): + continue + waiting_long_prefills += 1 + prefills += 1 + + # NB: long_prefills and waiting_long_prefills are tracked separately. + # We don't account for the waiting requests here because we need to use + # this metadata to track how many have actually been scheduled. + return PartialPrefillMetadata( + schedulable_prefills=min( + prefills, scheduler_config.max_num_partial_prefills), + long_prefills=long_prefills, + scheduler_config=scheduler_config, + ) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, + output_proc_callback: Optional[Callable] = None, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. 
+ self.lora_config = lora_config + + version = "selfattn" + if (self.scheduler_config.runner_type == "pooling" + or self.cache_config.is_attention_free): + version = "placeholder" + + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + version) + + num_gpu_blocks = cache_config.num_gpu_blocks + if num_gpu_blocks: + num_gpu_blocks //= pipeline_parallel_size + + num_cpu_blocks = cache_config.num_cpu_blocks + if num_cpu_blocks: + num_cpu_blocks //= pipeline_parallel_size + + # Create the block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching, + ) + + # Sequence groups in the WAITING state. + # Contain new prefill or preempted requests. + self.waiting: Deque[SequenceGroup] = deque() + # Sequence groups in the RUNNING state. + # Contain decode requests. + self.running: Deque[SequenceGroup] = deque() + # Sequence groups in the SWAPPED state. + # Contain decode requests that are swapped out. + self.swapped: Deque[SequenceGroup] = deque() + # Sequence groups finished requests ids since last step iteration. + # It lets the model know that any state associated with these requests + # can and must be released after the current step. + # This is used to evict the finished requests from the Mamba cache. + self._finished_requests_ids: List[str] = list() + # Time at previous scheduling step + self.prev_time = 0.0 + # Did we schedule a prompt at previous step? + self.prev_prompt = False + # Latency of the last prompt step + self.last_prompt_latency = 0.0 + # preemption mode, RECOMPUTE or SWAP + self.user_specified_preemption_mode = scheduler_config.preemption_mode + + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + self.num_cumulative_preemption: int = 0 + + # Used to cache python objects + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. + self.output_proc_callback = output_proc_callback + self.use_async_output_proc = self.output_proc_callback is not None + self.num_cache_iters = 2 if self.use_async_output_proc else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # For async postprocessor, the extra decode run cannot be done + # when the request reaches max_model_len. 
In this case, the request + # will be stopped during schedule() call and added to this stop list + # for processing and deallocation by the free_finished_seq_groups() + self._async_stopped: List[SequenceGroup] = [] + + # List with the chunk sizes to hand out to each sequence depending + # on how many partial prefills are running. This is slightly faster than + # running an integer division every time a prefill is scheduled. + # This splits the budget evenly among all prefills. + self.partial_prefill_budget_lookup_list = [0] * ( + self.scheduler_config.max_num_partial_prefills + 1) + self.partial_prefill_budget_lookup_list[0] = ( + scheduler_config.max_num_batched_tokens) + for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): + self.partial_prefill_budget_lookup_list[i] = ( + scheduler_config.max_num_batched_tokens // i) + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters + + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens.""" + return 1 + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the running queue. + # Only for testing purposes. + self.running.append(seq_group) + + def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the swapped queue. + # Only for testing purposes. + self.swapped.append(seq_group) + + def abort_seq_group( + self, + request_id: Union[str, Iterable[str]], + seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None, + ) -> None: + """Aborts a sequence group with the given ID. + + Check if the sequence group with the given ID + is present in any of the state queue. + If present, remove the sequence group from the state queue. + Also, if any of the sequences in the sequence group is not finished, + free the sequence with status `FINISHED_ABORTED`. + Otherwise, do nothing. + + Args: + request_id: The ID(s) of the sequence group to abort. + seq_id_to_seq_group: helper for groups with n>1 + """ + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + seq_id_to_seq_group = seq_id_to_seq_group or {} + for state_queue in [self.waiting, self.running, self.swapped]: + aborted_groups: List[SequenceGroup] = [] + for seq_group in state_queue: + # When n>1, seq_group.request_id looks like + # foo_parallel_sample_0, while request_ids is just foo, and we + # should resolve it as real_request_id to match. + if seq_group.request_id in seq_id_to_seq_group: + real_request_id = seq_id_to_seq_group[ + seq_group.request_id].group_id + else: + real_request_id = seq_group.request_id + if real_request_id in request_ids: + # Appending aborted group into pending list. + aborted_groups.append(seq_group) + # We can't remove real_request_id in request_ids here, + # because there may be other seq groups sharing the same + # real_request_id + for aborted_group in aborted_groups: + # Remove the sequence group from the state queue. + state_queue.remove(aborted_group) + # Remove the aborted request from the Mamba cache. 
+ self._finished_requests_ids.append(aborted_group.request_id) + for seq in aborted_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + if aborted_group.request_id in seq_id_to_seq_group: + del seq_id_to_seq_group[aborted_group.request_id] + + self._free_seq_group_cross_attn_blocks(aborted_group) + + def _free_seq_group_cross_attn_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + if seq_group.is_encoder_decoder(): + self.block_manager.free_cross(seq_group) + + def has_unfinished_seqs(self) -> bool: + return (len(self.waiting) != 0 or len(self.running) != 0 + or len(self.swapped) != 0) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.block_manager.reset_prefix_cache(device) + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def get_and_reset_finished_requests_ids(self) -> List[str]: + """Flushes the list of request ids of previously finished seq_groups.""" + finished_requests_ids = self._finished_requests_ids + self._finished_requests_ids = list() + return finished_requests_ids + + def _schedule_running( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> SchedulerRunningOutputs: + """Schedule sequence groups that are running. + + Running queue should include decode and chunked prefill requests. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any decodes are preempted. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any decodes are preempted. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running + + Returns: + SchedulerRunningOutputs. + """ + ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ + self.cache_id].get_object() + ret.blocks_to_swap_out.clear() + ret.blocks_to_copy.clear() + ret.decode_seq_groups.clear() + ret.prefill_seq_groups.clear() + ret.preempted.clear() + ret.swapped_out.clear() + + ret.num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking) + + ret.decode_seq_groups_list.clear() + ret.prefill_seq_groups_list.clear() + + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out + blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy + + decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups + prefill_seq_groups: List[ + ScheduledSequenceGroup] = ret.prefill_seq_groups + preempted: List[SequenceGroup] = ret.preempted + swapped_out: List[SequenceGroup] = ret.swapped_out + + running_queue = self.running + assert len(self._async_stopped) == 0 + while running_queue: + seq_group = running_queue[0] + # We discard the cached tokens info here because we don't need it + # for running sequence: + # 1. 
If a sequence is running with chunked prefill, the cached + # tokens info was already used for the first prefill. + # 2. If a sequence is running with non-chunked prefill, then + # there it's a decoding sequence, and the cached tokens info is + # irrelevant. + num_uncached_new_tokens, _ = \ + self._get_num_new_uncached_and_cached_tokens( + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + partial_prefill_metadata, + ) + + num_running_tokens = num_uncached_new_tokens + if num_running_tokens == 0: + # No budget => Stop + break + + running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. + if (self.use_async_output_proc and seq_group.seqs[0].get_len() + > self.scheduler_config.max_model_len): + self._async_stopped.append(seq_group) + continue + + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. + while not self._can_append_slots(seq_group, enable_chunking): + budget.subtract_num_batched_tokens(seq_group.request_id, + num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(seq_group.request_id, + num_running_seqs) + + if (curr_loras is not None and seq_group.lora_int_id > 0 + and seq_group.lora_int_id in curr_loras): + curr_loras.remove(seq_group.lora_int_id) + + # Determine victim sequence + cont_loop = True + if running_queue: + # Preempt the lowest-priority sequence group. + victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(victim_seq_group) + else: + swapped_out.append(victim_seq_group) + + if not cont_loop: + break + else: + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + is_prefill = seq_group.is_prefill() + + scheduled_seq_group: ScheduledSequenceGroup = ( + self._scheduled_seq_group_cache[ + self.cache_id].get_object()) + scheduled_seq_group.seq_group = seq_group + if is_prefill: + scheduled_seq_group.token_chunk_size = num_running_tokens + prefill_seq_groups.append(scheduled_seq_group) + ret.prefill_seq_groups_list.append(seq_group) + else: + scheduled_seq_group.token_chunk_size = 1 + decode_seq_groups.append(scheduled_seq_group) + ret.decode_seq_groups_list.append(seq_group) + + budget.add_num_batched_tokens(seq_group.request_id, + num_running_tokens) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. 
For the default scheduling case where
+                # enable_chunking is False, num_seqs are updated before running
+                # this method, so we don't have to update it again here.
+                if enable_chunking:
+                    num_running_seqs = seq_group.get_max_num_running_seqs()
+                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
+                if curr_loras is not None and seq_group.lora_int_id > 0:
+                    curr_loras.add(seq_group.lora_int_id)
+
+        self._scheduler_running_outputs_cache[self.next_cache_id].reset()
+        self._scheduled_seq_group_cache[self.next_cache_id].reset()
+
+        return ret
+
+    def _schedule_swapped(
+        self,
+        budget: SchedulingBudget,
+        curr_loras: Optional[Set[int]],
+        enable_chunking: bool = False,
+    ) -> SchedulerSwappedInOutputs:
+        """Schedule sequence groups that are swapped out.
+
+        It schedules swapped requests as long as it fits `budget` and
+        curr_loras <= max_lora from the scheduling config. The input arguments
+        `budget` and `curr_loras` are updated based on scheduled seq_groups.
+
+        Args:
+            budget: The scheduling budget. The argument is in-place updated
+                when any requests are swapped in.
+            curr_loras: Currently batched lora request ids. The argument is
+                in-place updated when any requests are swapped in.
+            enable_chunking: If True, seq group can be chunked and only a
+                chunked number of tokens are scheduled if
+                `budget.num_batched_tokens` has not enough capacity to schedule
+                all tokens.
+
+        Returns:
+            SchedulerSwappedInOutputs.
+        """
+        # Blocks that need to be swapped or copied before model execution.
+        blocks_to_swap_in: List[Tuple[int, int]] = []
+        blocks_to_copy: List[Tuple[int, int]] = []
+        decode_seq_groups: List[ScheduledSequenceGroup] = []
+        prefill_seq_groups: List[ScheduledSequenceGroup] = []
+        infeasible_seq_groups: List[SequenceGroup] = []
+
+        swapped_queue = self.swapped
+
+        leftover_swapped: Deque[SequenceGroup] = deque()
+        while swapped_queue:
+            seq_group = swapped_queue[0]
+
+            # If the sequence group cannot be swapped in, stop.
+            is_prefill = seq_group.is_prefill()
+            alloc_status = self.block_manager.can_swap_in(
+                seq_group,
+                self._get_num_lookahead_slots(is_prefill, enable_chunking))
+            if alloc_status == AllocStatus.LATER:
+                break
+            elif alloc_status == AllocStatus.NEVER:
+                logger.warning(
+                    "Failing the request %s because there's not enough kv "
+                    "cache blocks to run the entire sequence.",
+                    seq_group.request_id,
+                )
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.FINISHED_IGNORED
+                infeasible_seq_groups.append(seq_group)
+                swapped_queue.popleft()
+                continue
+
+            lora_int_id = 0
+            if self.lora_enabled:
+                lora_int_id = seq_group.lora_int_id
+                assert curr_loras is not None
+                assert self.lora_config is not None
+                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
+                        and len(curr_loras) >= self.lora_config.max_loras):
+                    # We don't have a space for another LoRA, so
+                    # we ignore this request for now.
+                    leftover_swapped.appendleft(seq_group)
+                    swapped_queue.popleft()
+                    continue
+
+            # The total number of sequences in the RUNNING state should not
+            # exceed the maximum number of sequences.
+ num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.SWAPPED, enable_chunking, + budget)) + + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): + break + + if lora_int_id > 0 and curr_loras is not None: + curr_loras.add(lora_int_id) + swapped_queue.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + if is_prefill: + prefill_seq_groups.append( + ScheduledSequenceGroup( + seq_group, + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) + else: + decode_seq_groups.append( + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + swapped_queue.extendleft(leftover_swapped) + + return SchedulerSwappedInOutputs( + decode_seq_groups=decode_seq_groups, + prefill_seq_groups=prefill_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking), + infeasible_seq_groups=infeasible_seq_groups, + ) + + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """Get the priority of the sequence group. + Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. 
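+
+        Illustrative example (the priority values are hypothetical): if the
+        head of the waiting queue has priority 0 while the lowest-priority
+        running group has priority 5, and the waiting group cannot yet be
+        allocated within the budget, that running group is preempted and
+        pushed back onto the waiting queue; groups with smaller
+        (priority, arrival_time) tuples are scheduled first.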
+ """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens_uncached, _ = \ + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, False, budget) + + # Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + # Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens_uncached > 0 + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): + break + + # Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget)) + budget.subtract_num_batched_tokens( + vseq_group.request_id, num_running_tokens_uncached) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + # Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + # Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + + def _schedule_prefills( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> SchedulerPrefillOutputs: + """Schedule sequence groups that are in prefill stage. + + Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + as a new prefill (that starts from beginning -> most recently generated + tokens). + + It schedules waiting requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are scheduled. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running + + Returns: + SchedulerPrefillOutputs. 
+ """ + if budget.remaining_token_budget() == 0: + # Do nothing: Can't add any more prefill anyway + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) + ignored_seq_groups: List[SequenceGroup] = [] + seq_groups: List[ScheduledSequenceGroup] = [] + + waiting_queue = self.waiting + + leftover_waiting_sequences: Deque[SequenceGroup] = deque() + while self._passed_delay(time.time()) and waiting_queue: + seq_group = waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + if (partial_prefill_metadata is not None + and not partial_prefill_metadata.can_schedule(seq_group)): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata, + )) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens + + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + "Input prompt (%d tokens) + lookahead slots (%d) is " + "too long and exceeds the capacity of block_manager", + num_new_tokens, + num_lookahead_slots, + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + if (budget.num_batched_tokens + >= self.scheduler_config.max_num_batched_tokens): + # We've reached the budget limit - since there might be + # continuous prefills in the running queue, we should break + # to avoid scheduling any new prefills. + break + + num_new_seqs = seq_group.get_max_num_running_seqs() + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): + break + + # Can schedule this request. 
+ if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + waiting_queue.popleft() + self._allocate_and_set_running(seq_group) + + if partial_prefill_metadata is not None: + partial_prefill_metadata.maybe_increment_partial_prefills( + seq_group) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. + num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking, + ) + + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + # Queue requests that couldn't be scheduled. + waiting_queue.extendleft(leftover_waiting_sequences) + if len(seq_groups) > 0: + self.prev_prompt = True + + return SchedulerPrefillOutputs( + seq_groups=seq_groups, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) + + def _schedule_default(self) -> SchedulerOutputs: + """Schedule queued requests. + + The current policy is designed to optimize the throughput. First, + it batches as many prefill requests as possible. And it schedules + decodes. If there's a pressure on GPU memory, decode requests can + be swapped or preempted. + """ + # Include running requests to the budget. + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + # Make sure we include num running seqs before scheduling prefill, + # so that we don't schedule beyond max_num_seqs for prefill. + for seq_group in self.running: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + curr_loras = (set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None) + + prefills = SchedulerPrefillOutputs.create_empty() + running_scheduled = SchedulerRunningOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # If any requests are swapped, prioritized swapped requests. + if not self.swapped: + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) + + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + + # Don't schedule decodes if prefills are scheduled. + # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running + # only contains decode requests, not chunked prefills. + if len(prefills.seq_groups) == 0: + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) + + # If any sequence group is preempted, do not swap in any sequence + # group. because it means there's no slot for new running requests. 
+ if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): + swapped_in = \ + self._schedule_swapped(budget, curr_loras) + + assert (budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. + if len(prefills.seq_groups) > 0: + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + self.running.extend(running_scheduled.decode_seq_groups_list) + + if len(swapped_in.decode_seq_groups) > 0: + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + preempted = len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) + + # There should be no prefill from running queue because this policy + # doesn't allow chunked prefills. + assert len(running_scheduled.prefill_seq_groups) == 0 + assert len(swapped_in.prefill_seq_groups) == 0 + + # Merge lists + num_prefill_groups = len(prefills.seq_groups) + if num_prefill_groups > 0: + scheduled_seq_groups = prefills.seq_groups + scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + else: + scheduled_seq_groups = running_scheduled.decode_seq_groups + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) + + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=preempted, + ) + + def _schedule_chunked_prefill(self) -> SchedulerOutputs: + """Schedule queued requests. + + Chunked prefill allows to chunk prefill requests, batch them together + with decode requests. This policy 1. schedule as many decoding requests + as possible. 2. schedule chunked prefill requests that are not + finished. 3. schedule swapped request. 4. schedule new prefill + requests. + + The policy can sustain the high GPU utilization because it can put + prefill and decodes requests to the same batch, while it improves + inter token latency because decodes requests don't need to be blocked + by prefill requests. + """ + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + curr_loras: Set[int] = set() + + prefills = SchedulerPrefillOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # Create partial prefill metadata + partial_prefill_metadata = PartialPrefillMetadata.from_queues( + running=self.running, + waiting=self.waiting, + scheduler_config=self.scheduler_config, + ) + + # Decoding should be always scheduled first by fcfs. + running_scheduled = self._schedule_running( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) + + # Schedule swapped out requests. 
+ # If preemption happens, it means we don't have space for swap-in. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + swapped_in = self._schedule_swapped(budget, curr_loras) + + prefills = self._schedule_prefills( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) + + assert (budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + + # Update new running requests. + # By default, vLLM scheduler prioritizes prefills. + # Once chunked prefill is enabled, + # the policy is changed to prioritize decode requests. + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + # Because multiple prefills may be running concurrently, we need to + # make sure that prefills which are scheduled to finish are listed + # before those that won't. This is so that on the next scheduling + # iteration when they have transitioned to the decode stage, they are + # properly prioritized over sequences that are still in the prefill + # stage. + self.running.extend( + self._order_finishing_prefills_first( + running_scheduled.prefill_seq_groups)) + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + # Put prefills first due to Attention backend ordering assumption. + scheduled_seq_groups = (prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + num_prefill_groups = (len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)) + # If all prompts, then we set num_lookahead_slots to 0 + # this allows us to go through the `no_spec` path in + # `spec_decode_worker.py` + all_prefills = len(scheduled_seq_groups) == num_prefill_groups + num_lookahead_slots = (0 if + (all_prefills + and not self.scheduler_config.is_multi_step) + else running_scheduled.num_lookahead_slots) + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, + num_lookahead_slots=num_lookahead_slots, + running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), + ) + + def _order_finishing_prefills_first( + self, scheduled_prefill_seqs: List[ScheduledSequenceGroup] + ) -> List[SequenceGroup]: + """Returns a list of prefilling SequenceGroups where sequences that are + scheduled to finish prefilling are listed first""" + finishing = [ + s.seq_group for s in scheduled_prefill_seqs + if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size + ] + not_finishing = [ + s.seq_group for s in scheduled_prefill_seqs + if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size + ] + 
return finishing + not_finishing + + def _schedule(self) -> SchedulerOutputs: + """Schedule queued requests.""" + if self.scheduler_config.chunked_prefill_enabled: + return self._schedule_chunked_prefill() + else: + return self._schedule_default() + + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: + """Determine whether or not we have enough space in the KV cache to + continue generation of the sequence group. + """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + + is_prefill = seq_group.is_prefill() + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + if is_prefill and num_lookahead_slots > 0: + # Appending prefill slots only happens multi-step and + # chunked-prefill are enabled together. + assert self.scheduler_config.is_multi_step and enable_chunking + + return self.block_manager.can_append_slots( + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) + + def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + # async_output_proc is allowed only when we have a single sequence + # in the sequence group + no_single_seq = seq_group.sampling_params is None or ( + seq_group.sampling_params.n == 1) + return no_single_seq + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_start_time = time.perf_counter() + + scheduler_outputs: SchedulerOutputs = self._schedule() + now = time.time() + + if not self.cache_config.enable_prefix_caching: + common_computed_block_nums = [] + + allow_async_output_proc: bool = self.use_async_output_proc + + # Create input data structures. 
+ seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size + seq_group.maybe_set_first_scheduled_time(now) + + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + + # seq_id -> SequenceData + seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers + block_tables: Dict[int, List[int]] = {} + + if seq_group.is_encoder_decoder(): + # Encoder associated with SequenceGroup + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + encoder_seq_data = encoder_seq.data + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table = self.block_manager.get_cross_block_table( + seq_group) + else: + encoder_seq_data = None + cross_block_table = None + + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + self.block_manager.access_all_blocks_in_seq(seq, now) + + if self.cache_config.enable_prefix_caching: + common_computed_block_nums = ( + self.block_manager.get_common_computed_block_ids( + seq_group.get_seqs(status=SequenceStatus.RUNNING))) + + do_sample = True + is_prompt = seq_group.is_prefill() + # We should send the metadata to workers when the first prefill + # is sent. Subsequent requests could be chunked prefill or decode. + is_first_prefill = False + if is_prompt: + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + num_computed_tokens = seqs[0].data.get_num_computed_tokens() + is_first_prefill = num_computed_tokens == 0 + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + num_computed_tokens + < seqs[0].data.get_len()): + do_sample = False + + # It assumes the scheduled_seq_groups is ordered by + # prefill < decoding. + if is_first_prefill or not self.scheduler_config.send_delta_data: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, + token_chunk_size=token_chunk_size, + lora_request=seq_group.lora_request, + computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + state=seq_group.state, + token_type_ids=seq_group.token_type_ids, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. 
+ multi_modal_data=(seq_group.multi_modal_data + if scheduler_outputs.num_prefill_groups + > 0 else None), + multi_modal_placeholders=( + seq_group.multi_modal_placeholders + if scheduler_outputs.num_prefill_groups > 0 else None), + mm_processor_kwargs=seq_group.mm_processor_kwargs, + prompt_adapter_request=seq_group.prompt_adapter_request, + ) + else: + # When SPMD mode is enabled, we only send delta data except for + # the first request to reduce serialization cost. + seq_data_delta = {} + for id, data in seq_data.items(): + seq_data_delta[id] = data.get_delta_and_reset() + seq_group_metadata = SequenceGroupMetadataDelta( + seq_data_delta, + seq_group.request_id, + block_tables, + is_prompt, + do_sample=do_sample, + token_chunk_size=token_chunk_size, + computed_block_nums=common_computed_block_nums, + ) + seq_group_metadata_list.append(seq_group_metadata) + + if allow_async_output_proc: + allow_async_output_proc = self._allow_async_output_proc( + seq_group) + + # Now that the batch has been created, we can assume all blocks in the + # batch will have been computed before the next scheduling invocation. + # This is because the engine assumes that a failure in model execution + # will crash the vLLM instance / will not retry. + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed( + scheduled_seq_group.seq_group, + scheduled_seq_group.token_chunk_size) + + self._seq_group_metadata_cache[self.next_cache_id].reset() + + scheduler_time = time.perf_counter() - scheduler_start_time + # Add this to scheduler time to all the sequences that are currently + # running. This will help estimate if the scheduler is a significant + # component in the e2e latency. + for seq_group in self.running: + if seq_group is not None and seq_group.metrics is not None: + if seq_group.metrics.scheduler_time is not None: + seq_group.metrics.scheduler_time += scheduler_time + else: + seq_group.metrics.scheduler_time = scheduler_time + + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" + self.block_manager.free(seq) + + def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: + """Free finished seqs in a sequence group.""" + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. 
+ self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + def free_finished_seq_groups(self) -> None: + remaining: Deque[SequenceGroup] = deque() + for seq_group in self.running: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): + remaining.append(seq_group) + + self.running = remaining + + # Handle async stopped sequence groups + # (ones that reached max model len) + if self._async_stopped: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + self._async_stopped.clear() + + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slots( + self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False, + ) -> None: + """Appends new slots to the sequences in the given sequence group. + + Args: + seq_group (SequenceGroup): The sequence group containing the + sequences to append slots to. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. + enable_chunking (bool): True if chunked prefill is enabled. + """ + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking, + ) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): + cows = self.block_manager.append_slots(seq, num_lookahead_slots) + if len(cows) > 0: + blocks_to_copy.extend(cows) + + def _preempt(self, seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. 
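+        # Rough illustration of the resulting policy (assuming the user did
+        # not force a preemption mode): a group with a single sequence is
+        # preempted by RECOMPUTE (its blocks are freed and the sequence goes
+        # back to WAITING to be prefilled again later), while a group with
+        # multiple sequences (e.g. beam search) is preempted by SWAP (its KV
+        # blocks are copied out to CPU swap space).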
+ if self.user_specified_preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + + elif self.user_specified_preemption_mode == "swap": + preemption_mode = PreemptionMode.SWAP + else: + preemption_mode = PreemptionMode.RECOMPUTE + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", + seq_group.request_id, + preemption_mode, + self.num_cumulative_preemption + 1, + ) + self.num_cumulative_preemption += 1 + + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + return preemption_mode + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.free_seq(seq) + seq.reset_state_for_recompute() + self._free_seq_group_cross_attn_blocks(seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: List[Tuple[int, int]], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED + + def _passed_delay(self, now: float) -> bool: + if self.prev_prompt: + self.last_prompt_latency = now - self.prev_time + self.prev_time, self.prev_prompt = now, False + # Delay scheduling prompts to let waiting queue fill up + if self.scheduler_config.delay_factor > 0 and self.waiting: + earliest_arrival_time = min( + [e.metrics.arrival_time for e in self.waiting]) + passed_delay = ((now - earliest_arrival_time) + > (self.scheduler_config.delay_factor * + self.last_prompt_latency) or not self.running) + else: + passed_delay = True + return passed_delay + + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: + """The number of slots to allocate per sequence per step, beyond known + token ids. Speculative decoding uses these slots to store KV activations + of tokens which may or may not be accepted. + + Speculative decoding does not yet support prefill, so we do not perform + lookahead allocation for prefill. 
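+
+        For example (illustrative number): with num_lookahead_slots=4, each
+        decode step reserves four extra KV slots per sequence for proposed
+        speculative tokens that may later be rejected.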
+ + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. + """ + if is_prefill: + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 + + return self.scheduler_config.num_lookahead_slots + + def _get_num_new_uncached_and_cached_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> Tuple[int, int]: + """ + Returns the number of new uncached and cached tokens to schedule for a + given sequence group that's in a given `status`. + + The API could chunk the number of tokens to compute based on `budget` + if `enable_chunking` is True. If a sequence group has multiple + sequences (e.g., running beam search), it means it is in decoding + phase, so chunking doesn't happen. + + Returns (0, 0) if the new token cannot be computed due to token budget. + + The cached tokens's blocks are already computed, and the attention + backend will reuse the cached blocks rather than recomputing them. So + the scheduler could schedule these cached tokens "for free". + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + partial_prefill_metadata: information about the partial prefills + that are currently running + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens. + If no more new tokens can be scheduled, returns (0, 0). + """ + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + + seqs = seq_group.get_seqs(status=status) + # Compute the number of new uncached and cached tokens for + # each sequence. + for seq in seqs: + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue + + num_computed_tokens_seq = seq.get_num_computed_tokens() + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + if not self.cache_config.enable_prefix_caching: + # If prefix caching is not enabled, all new tokens are uncached. + num_uncached_new_tokens += all_num_new_tokens_seq + continue + + # NOTE: the cache token might be currently in a block that's in an + # evictor meaning that it's not yet allocated. However, we don't + # exclude such tokens in the cache count because it will be + # guaranteed to be allocated later if the sequence can be allocated. + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) + + # Sanity check. + if num_cached_tokens_seq < num_computed_tokens_seq: + # This should only happen with chunked prefill, and + # the seq is still in prefill. 
The `num_cached_tokens_seq` + # is the value we calculated on scheduling the first prefill. + # For subsequent continuous prefill steps, we cached the + # number of cache tokens for the sequence so the cached token + # count could be less than the number of computed tokens. + # See comments on `ComputedBlocksTracker` for more details. + assert ( + seq.is_prefill() and seq.status == SequenceStatus.RUNNING + and self.scheduler_config.chunked_prefill_enabled + ), ("Number of cached tokens should not be less than the " + "number of computed tokens for a sequence that's still " + f"in prefill. But there are {num_cached_tokens_seq} cached " + f"tokens and {num_computed_tokens_seq} computed tokens " + f"for sequence {seq.seq_id}.") + + num_cached_new_tokens_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) + + num_uncached_new_tokens += num_uncached_new_tokens_seq + num_cached_new_tokens += num_cached_new_tokens_seq + + if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: + # For a fully cached hit sequence, we actually need to recompute the + # last token. So we need at least 1 uncached token to schedule. + # See ModelRunner._compute_for_prefix_cache_hit for more details. + num_uncached_new_tokens = 1 + num_cached_new_tokens -= 1 + + if enable_chunking and len(seqs) == 1: + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + self.scheduler_config, + self.cache_config, + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + self.partial_prefill_budget_lookup_list, + partial_prefill_metadata, + ) + + return num_uncached_new_tokens, num_cached_new_tokens + + @staticmethod + def _chunk_new_tokens_to_schedule( + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + partial_prefill_budget_lookup_list: List[int], + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> int: + """ + Chunks the number of new tokens to schedule based on the budget when + chunked prefill is enabled. + + Args: + scheduler_config: The scheduler config. + cache_config: The cache config. + budget: The budget to chunk the number of tokens to compute. + prompt_limit: The maximum number of tokens allowed in a prompt. + num_new_tokens: The number of new tokens to schedule. + + Returns: + The number of new tokens to schedule after chunking. + """ + remaining_token_budget = budget.remaining_token_budget() + if scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > prompt_limit: + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. 
+ return num_new_tokens + + return 0 if num_new_tokens > \ + remaining_token_budget else num_new_tokens + + # Get the number of tokens to allocate to this prefill slot + prefill_slot_budget = ( + remaining_token_budget if partial_prefill_metadata is None else + partial_prefill_budget_lookup_list[ + partial_prefill_metadata.schedulable_prefills]) + + if cache_config.enable_prefix_caching: + # When prefix caching is enabled and we're partially prefilling + # a sequence, we always allocate a number of new tokens that is + # divisible by the block size to avoid partial block matching. + block_size = cache_config.block_size + # Don't exceed either the total budget or slot budget. + # Take min of those and get the next lowest multiple of the + # block size: + remaining_token_budget = ( + min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size + # NB: In the case where num_new_tokens < budget, we are + # finishing prefill for this sequence, so we do not need to + # allocate a full block. + + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) + + return num_new_tokens diff --git a/vllm/device_allocator/__init__.py b/vllm/device_allocator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py new file mode 100644 index 0000000..9ff77f1 --- /dev/null +++ b/vllm/device_allocator/cumem.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 + +# cumem-based pytorch pluggable allocator to implement sleep mode. +# other approaches tried but failed: +# - cuda-python package binding +# - custom libcuda driver ctypes wrapper +# both of them failed because of cuda context mismatch. +# not sure why, they are created from a different context. +# the only successful approach is to call cuda driver API in C. +import dataclasses +import gc +import os +from contextlib import contextmanager +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch + +from vllm.utils import is_pin_memory_available + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
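+
+    For illustration only (the addresses, permissions, device/inode fields and
+    version suffix all vary across systems), a matching line in
+    `/proc/self/maps` might look like:
+
+        7f2c3a000000-7f2c3a200000 r-xp 00000000 08:01 131073  /usr/lib/x86_64-linux-gnu/libcudart.so.12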
+ """ # noqa + found_line = None + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found_line = line + break + if found_line is None: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = found_line.index("/") + path = found_line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +cumem_available = False +try: + from vllm.cumem_allocator import (init_module, python_create_and_map, + python_unmap_and_release) + from vllm.distributed.device_communicators.cuda_wrapper import ( + CudaRTLibrary) + lib_name = find_loaded_library("cumem_allocator") + libcudart = CudaRTLibrary() + cumem_available = True +except ModuleNotFoundError: + # rocm platform does not support cumem allocator + init_module = None + python_create_and_map = None + python_unmap_and_release = None + CudaRTLibrary = None + lib_name = None + libcudart = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = Tuple[int, int, int, int] + + +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: Optional[torch.Tensor] = None + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[int], + int], python_free_func: Callable[[int, int], + None] +) -> torch.cuda.memory.CUDAPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch.cuda.memory.CUDAPluggableAllocator( + lib_name, 'my_malloc', 'my_free') + return new_alloc + + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[int], int], + python_free_func: Callable[[int, int], None]) -> None: + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) + with torch.cuda.memory.use_mem_pool(mem_pool): + yield mem_pool, new_alloc + + +class CuMemAllocator: + """ + A singleton class that manages a memory pool for CUDA tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. + + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + instance: "CuMemAllocator" = None + default_tag: str = "default" + + @staticmethod + def get_instance() -> "CuMemAllocator": + """ + CuMemAllocator is a singleton class. 
+ We cannot call the constructor directly. + Call this method to get the instance. + """ + assert cumem_available, "cumem allocator is not available" + if CuMemAllocator.instance is None: + CuMemAllocator.instance = CuMemAllocator() + return CuMemAllocator.instance + + def __init__(self): + conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + assert "expandable_segments:True" not in conf, \ + ("Expandable segments are not compatible with memory pool. " + "Please track https://github.com/pytorch/pytorch/issues/147851 " + "for the latest updates.") + + self.pointer_to_data: Dict[int, AllocationData] = {} + self.current_tag: str = CuMemAllocator.default_tag + self.allocator_and_pools: Dict[str, Any] = {} + + def python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" + py_d_mem = allocation_handle[2] + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag) + return + + def python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + return data.handle + + def sleep( + self, + offload_tags: Optional[Union[Tuple[str, ...], + str]] = None) -> None: + """ + Put the allocator in sleep mode. + All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded. + + :param offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ + if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps + offload_tags = (CuMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags, ) + + assert isinstance(offload_tags, tuple) + + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + if data.tag in offload_tags: + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) + data.cpu_backup_tensor = cpu_backup_tensor + unmap_and_release(handle) + + gc.collect() + torch.cuda.empty_cache() + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory. + + :param tags: The tags of the memory allocation that will be loaded + back to GPU memory. If None, all memory allocation will be loaded + back to GPU memory. + """ + for ptr, data in self.pointer_to_data.items(): + if tags is None or data.tag in tags: + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + data.cpu_backup_tensor = None + + @contextmanager + def use_memory_pool(self, tag: Optional[str] = None): + """ + A context manager to use the memory pool. 
+ All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag. + + :param tag: The tag of the memory allocation. If None, the default tag + will be used. + """ + if tag is None: + tag = CuMemAllocator.default_tag + + assert isinstance(tag, str) + + old_tag = self.current_tag + self.current_tag = tag + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator and + # the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released. + # right now it is fine, because we only use this allocator + # during weight loading and kv cache creation, where we only + # allocate memory. + # TODO: we need to find a way to release the memory, + # i.e. calling torch.cuda.empty_cache() + self.current_tag = old_tag + + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. + """ + sum_bytes: int = 0 + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + sum_bytes += handle[1] + return sum_bytes diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py new file mode 100644 index 0000000..39955dd --- /dev/null +++ b/vllm/distributed/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py new file mode 100644 index 0000000..0228264 --- /dev/null +++ b/vllm/distributed/communication_op.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional, Union + +import torch +import torch.distributed + +from .parallel_state import get_tp_group + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) + + +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) + + +def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/vllm/distributed/device_communicators/__init__.py b/vllm/distributed/device_communicators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py new file mode 100644 index 0000000..fe5182c --- /dev/null +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -0,0 +1,164 @@ 
+# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +import ixformer.distributed as ixfd +import os + + +class DeviceCommunicatorBase: + """ + Base class for device-specific communicator. + It can use the `cpu_group` to initialize the communicator. + If the device has PyTorch integration (PyTorch can recognize its + communication backend), the `device_group` will also be given. + """ + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + self.device = device or torch.device("cpu") + self.cpu_group = cpu_group + self.device_group = device_group + self.unique_name = unique_name + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, + self.global_rank) + self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"] + if "pp" in unique_name: + # pipeline parallel does not need custom allreduce + use_custom_allreduce = False + else: + from vllm.distributed.parallel_state import ( + _ENABLE_CUSTOM_ALL_REDUCE) + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + self.use_custom_allreduce = use_custom_allreduce + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + self.ca_comm: Optional[CustomAllreduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + + if self.world_size == 1: + return input_ + + if self.use_vllm_comm: + ixfd.all_reduce(input_, group=self.device_group, async_op=True) + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + if self.use_vllm_comm: + ixfd.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group, + async_op=True) + else: + torch.distributed.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. 
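+        NOTE: only the destination rank receives the concatenated tensor;
+        every other rank returns None.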
+ """ + world_size = self.world_size + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + if self.use_vllm_comm: + ixfd.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group, + async_op=True) + else: + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + if self.use_vllm_comm: + ixfd.send(tensor, self.ranks[dst], self.device_group) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + if self.use_vllm_comm: + ixfd.recv(tensor, self.ranks[src], self.device_group) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + pass diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py new file mode 100644 index 0000000..1f4b4fa --- /dev/null +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import List, Optional + +import torch +from torch.distributed import ProcessGroup + +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from .base_device_communicator import DeviceCommunicatorBase + + +class CpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + self.dist_module = torch.distributed + + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + self.dist_module = _CPUSHMDistributed(self) + + def all_reduce(self, input_): + self.dist_module.all_reduce(input_, group=self.device_group) + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + + # Gather. 
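+        # `dist_module` is either plain torch.distributed or the x86
+        # shared-memory backend selected in __init__.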
+ self.dist_module.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + self.dist_module.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + + # Reshape + output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + +class _CPUSHMDistributed: + + def __init__(self, communicator: CpuCommunicator): + instance_identifier = os.environ["VLLM_DIST_IDENT"] + self.communicator = communicator + + group_ranks = [str(rank) for rank in self.communicator.ranks] + shm_group_identifier = f"[{'-'.join(group_ranks)}]" + self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm" + + self.handle = self._init_cpu_shm() + + def _init_cpu_shm(self) -> int: + handle = torch.ops._C.init_shm_manager( + self.group_name, + self.communicator.world_size, + self.communicator.rank, + ) + torch.distributed.barrier(self.communicator.device_group) + torch.ops._C.join_shm_manager( + handle, + self.group_name, + ) + torch.distributed.barrier(self.communicator.device_group) + + return handle + + def all_reduce(self, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_allreduce(self.handle, input) + + def gather(self, + input: torch.Tensor, + gather_list: Optional[List[torch.Tensor]], + dst: int = -1, + group: Optional[ProcessGroup] = None) -> None: + # Note: different from the torch gather, here we use local dst rank. 
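+        # The caller passes a global rank; get_group_rank converts it back to
+        # the group-local rank expected by the shared-memory op.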
+ torch.ops._C.shm_gather(self.handle, input, gather_list, + torch.distributed.get_group_rank(group, dst)) + + def all_gather_into_tensor(self, + output: torch.Tensor, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_all_gather(self.handle, input, output) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py new file mode 100644 index 0000000..07c9ff5 --- /dev/null +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +from .base_device_communicator import DeviceCommunicatorBase + + +class CudaCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + if "tp" not in unique_name: + # only tp uses custom allreduce + use_custom_allreduce = False + else: + from vllm.distributed.parallel_state import ( + _ENABLE_CUSTOM_ALL_REDUCE) + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + use_pynccl = True + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) + + self.pynccl_comm: Optional[PyNcclCommunicator] = None + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + self.ca_comm: Optional[CustomAllreduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + def all_reduce(self, input_): + # always try custom allreduce first, + # and then pynccl. + ca_comm = self.ca_comm + if ca_comm is not None and not ca_comm.disabled and \ + ca_comm.should_custom_ar(input_): + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + out = pynccl_comm.all_reduce(input_) + if out is None: + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. 
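+            # Clone first so this fallback stays out-of-place, matching the
+            # semantics of the custom-allreduce and pynccl paths above.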
+ out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) + return out + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py new file mode 100644 index 0000000..1d53b1c --- /dev/null +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. +""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
+ """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("cudaMalloc", cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("cudaMemset", cudaError_t, + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("cudaMemcpy", cudaError_t, [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind + ]), + + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("cudaIpcGetMemHandle", cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function("cudaIpcOpenMemHandle", cudaError_t, [ + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint + ]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = find_loaded_library("libcudart") + if so_file is None: + so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var + assert so_file is not None, \ + ( + "libcudart is not loaded in the current process, " + "try setting VLLM_CUDART_SO_PATH" + ) + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return 
self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, + count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, + count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( + ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + return devPtr diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py new file mode 100644 index 0000000..45fc2a7 --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -0,0 +1,301 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from typing import List, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed.device_communicators.custom_all_reduce_utils import ( + gpu_p2p_access_check) +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +try: + ops.meta_size() + custom_ar = True +except Exception: + # For CPUs + custom_ar = False + +logger = init_logger(__name__) + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if envs.VLLM_SKIP_P2P_CHECK: + logger.info( + "Skipping P2P check and trusting the driver's P2P report.") + return torch.cuda.can_device_access_peer(rank, i) + if not gpu_p2p_access_check(rank, i): + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__(self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". 
+ It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-GPU environment + logger.info("Custom allreduce is disabled because " + "of missing custom allreduce library") + return + + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom allreduce for multi-node case. + logger.warning( + "Custom allreduce is disabled because this process group" + " spans across nodes.") + return + + rank = dist.get_rank(group=self.group) + self.rank = rank + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda_alike() + fully_connected = current_platform.is_fully_connected( + physical_device_ids) + if world_size > 2 and not fully_connected: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not current_platform.is_rocm() and not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.") + return + + self.disabled = False + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. 
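+        # Each shared region is therefore ops.meta_size() bytes of
+        # synchronization metadata followed by max_size bytes of scratch space.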
+ self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, + group=group, + uncached=True) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + self.fully_connected = fully_connected + self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, + self.fully_connected) + ops.register_buffer(self._ptr, self.buffer_ptrs) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + logger.info("Registering %d cuda graph addresses", len(offset)) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + all_data = [[None, None] + for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + ops.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.fully_connected: + return inp_size < self.max_size + return False + + def all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if out is None: + out = torch.empty_like(inp) + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], + self.max_size) + return out + + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + """The main allreduce API that provides support for cuda graph.""" + # When custom allreduce is disabled, this will be None. 
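+        # Possible outcomes: None when disabled or the input is unsupported
+        # (the caller then falls back to pynccl/NCCL), an uninitialized tensor
+        # of matching shape during graph warm-up, and the reduced result
+        # during capture or eager execution.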
+ if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce(input, registered=True) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. + return torch.empty_like(input) + else: + # Note: outside of cuda graph context, custom allreduce incurs a + # cost of cudaMemcpy, which should be small (<=1% of overall + # latency) compared to the performance gain of using custom kernels + return self.all_reduce(input, registered=False) + + def close(self): + if not self.disabled and self._ptr: + ops.dispose(self._ptr) + self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs, rank=self.rank) + self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) + + def __del__(self): + self.close() + + @staticmethod + def create_shared_buffer(size_in_bytes: int, + group: Optional[ProcessGroup] = None, + uncached: Optional[bool] = False) -> List[int]: + pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) + + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer) # type: ignore + else: + pointers.append(ops.open_mem_handle(h)) + return pointers + + @staticmethod + def free_shared_buffer(pointers: List[int], + group: Optional[ProcessGroup] = None, + rank: Optional[int] = 0) -> None: + if rank is None: + rank = dist.get_rank(group=group) + ops.free_shared_buffer(pointers[rank]) diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py new file mode 100644 index 0000000..d8d6eed --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 + +import ctypes +import json +import os +import pickle +import subprocess +import sys +import tempfile +from itertools import product +from typing import Dict, List, Optional, Sequence + +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from vllm.logger import init_logger +from vllm.utils import (cuda_device_count_stateless, + update_environment_variables) + +logger = init_logger(__name__) + + +def producer(batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer(batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + 
cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. + pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + lib.cudaDeviceSynchronize() + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +) -> Sequence[bool]: + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU src --> cuda context src --> tensor src --> process src + + We need to combine p2p and cuda IPC, so that: + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because + they are the same memory segment. + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). 
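+
+    For example, with two visible devices the full pairwise check issued on
+    behalf of `gpu_p2p_access_check` is `can_actually_p2p([0, 0, 1, 1], [0, 1, 0, 1])`,
+    returning one boolean per (src, tgt) pair.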
+ """ # noqa + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the processes are spawned + smp = mp.get_context("spawn") + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process(target=producer, + args=(batch_src, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_tgt = smp.Process(target=consumer, + args=(batch_tgt, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 + result: List[bool] = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + if a != b: + logger.warning( + "Two processes do not agree on the P2P access" + " status on %d -> %d, treat as disabled.", src, tgt) + result.append(False) + else: + result.append(a) + return result + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + is_distributed = dist.is_initialized() + + num_dev = cuda_device_count_stateless() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + + path = os.path.join( + envs.VLLM_CACHE_ROOT, + f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + from vllm.distributed.parallel_state import get_world_group + if ((not is_distributed or get_world_group().local_rank == 0) + and (not os.path.exists(path))): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache in %s", path) + cache: Dict[str, bool] = {} + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + # NOTE: we use `subprocess` rather than `multiprocessing` here + # because the caller might not have `if __name__ == "__main__":`, + # in that case we cannot use spawn method in multiprocessing. + # However, `can_actually_p2p` requires spawn method. + # The fix is, we use `subprocess` to call the function, + # where we have `if __name__ == "__main__":` in this file. 
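+        # The cache written below is a flat JSON object keyed by local device
+        # pairs, e.g. {"0->0": true, "0->1": true, "1->0": true, "1->1": true}.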
+ + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps( + (batch_src, batch_tgt, output_file.name)) + returned = subprocess.run([sys.executable, __file__], + input=input_bytes, + capture_output=True) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + get_world_group().barrier() + logger.info("reading GPU P2P access cache from %s", path) + with open(path) as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + +__all__ = ["gpu_p2p_access_check"] + +if __name__ == "__main__": + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) + result = can_actually_p2p(batch_src, batch_tgt) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py new file mode 100644 index 0000000..9536a7f --- /dev/null +++ b/vllm/distributed/device_communicators/hpu_communicator.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +if current_platform.is_hpu(): + import habana_frameworks.torch as htorch # noqa: F401 + + +class HpuCommunicator(DeviceCommunicatorBase): + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + dist.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. 
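+        # Same workaround as in all_reduce above: flush pending lazy ops
+        # before issuing the collective.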
+ htorch.core.mark_step() + dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor diff --git a/vllm/distributed/device_communicators/neuron_communicator.py b/vllm/distributed/device_communicators/neuron_communicator.py new file mode 100644 index 0000000..dfa4b51 --- /dev/null +++ b/vllm/distributed/device_communicators/neuron_communicator.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase) +from vllm.platforms import current_platform + +if current_platform.is_neuron(): + import torch_xla.core.xla_model as xm + + +class NeuronCommunicator(DeviceCommunicatorBase): + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return xm.all_reduce(xm.REDUCE_SUM, x) + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "Neuron only supports dim=-1 for all-gather." + return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py new file mode 100644 index 0000000..0ccd423 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Union + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger +from vllm.utils import current_stream + +logger = init_logger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "PyNcclCommunicator should be attached to a non-NCCL group.") + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. 
in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) + + stream = current_stream() + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def all_reduce(self, + in_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + return out_tensor + + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on 
{input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000..82933fa --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. 
We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t 
datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. 
" + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: + """ + Reconstructs an `ncclUniqueId` object from bytes data. + + Args: + data: Must be a 128-byte data block (matching NCCL's unique_id). + + Returns: + ncclUniqueId: The reconstructed NCCL Unique ID object. + + Raises: + ValueError: If the input data length is not 128 bytes. + """ + if len(data) != 128: + raise ValueError( + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + + unique_id = ncclUniqueId() + ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + 
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] \ No newline at end of file diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py new file mode 100644 index 0000000..11ed7c0 --- /dev/null +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -0,0 +1,542 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import pickle +import sys +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from multiprocessing import shared_memory +from typing import List, Optional, Tuple, Union +from unittest.mock import patch + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore + +import vllm.envs as envs +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger +from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, + is_valid_ipv6_address) + +VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL + +logger = init_logger(__name__) + +# We prefer to use os.sched_yield as it results in tighter polling loops, +# measured to be around 3e-7 seconds. However on earlier versions of Python +# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) +USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) + or (sys.version_info[:2] == (3, 10) + and sys.version_info[2] >= 8)) + + +def sched_yield(): + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(0) + + +class ShmRingBuffer: + + def __init__(self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None): + """ + A shared memory ring buffer implementation for broadcast communication. + Essentially, it is a queue where only one will `enqueue` and multiple + will `dequeue`. The max size of each item, together with the max number + of items that can be stored in the buffer are known in advance. + In this case, we don't need to synchronize the access to + the buffer. + + Buffer memory layout: + data metadata + | | + | (current_idx) | (current_idx) + v v + +-------------------------------+----------------------------------------+ + | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... 
| metadata | + +-------------------------------+----------------------------------------+ + | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes | + + metadata memory layout: each byte is a flag, the first byte is the written + flag, and the rest are reader flags. The flags are set to 0 by default. + +--------------+--------------+--------------+-----+--------------+ + | written_flag | reader0_flag | reader1_flag | ... | readerN_flag | + +--------------+--------------+--------------+-----+--------------+ + + The state of metadata is as follows: + + (case 1) 0???...???: the block is not written yet, cannot read, can write + (case 2) 1000...000: the block is just written, can read, cannot write + (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write + (case 4) 1111...111: the block is written and read by all readers, cannot read, can write + + State transition for readers: + + When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. + Only after the caller finishes reading the block, the reader can mark the block as read. + Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0). + + State transition for writer: + + When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case + to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer + can reset the reader flags to 0, and mark the block as written (from 0 to 1). + NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. + + During creation, `name` is None and the buffer is created. We can pass the + created object to other processes by pickling it. The other processes will + get the name of the shared memory and open it, so that they can access the + same shared memory buffer. + """# noqa + self.n_reader = n_reader + self.metadata_size = 1 + n_reader + self.max_chunk_bytes = max_chunk_bytes + self.max_chunks = max_chunks + self.total_bytes_of_buffer = (self.max_chunk_bytes + + self.metadata_size) * self.max_chunks + self.data_offset = 0 + self.metadata_offset = self.max_chunk_bytes * self.max_chunks + + if name is None: + # we are creating a buffer + self.is_creator = True + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.total_bytes_of_buffer) + # initialize the metadata section to 0 + with memoryview(self.shared_memory.buf[self.metadata_offset:] + ) as metadata_buffer: + torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) + else: + # we are opening an existing buffer + self.is_creator = False + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + try: + self.shared_memory = shared_memory.SharedMemory(name=name) + # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa + # Some platforms allocate memory based on page size, + # so the shared memory block size may be larger or equal + # to the requested size. 
The size parameter is ignored + # when attaching to an existing block. + assert (self.shared_memory.size + >= self.total_bytes_of_buffer) + except FileNotFoundError: + # we might deserialize the object in a different node + # in this case, this object is not used, + # and we should suppress the error + pass + + def handle(self): + return (self.n_reader, self.max_chunk_bytes, self.max_chunks, + self.shared_memory.name) + + def __reduce__(self): + return ( + self.__class__, + self.handle(), + ) + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_creator: + self.shared_memory.unlink() + + @contextmanager + def get_data(self, current_idx: int): + start = self.data_offset + current_idx * self.max_chunk_bytes + end = start + self.max_chunk_bytes + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + @contextmanager + def get_metadata(self, current_idx: int): + start = self.metadata_offset + current_idx * self.metadata_size + end = start + self.metadata_size + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + +@dataclass +class Handle: + local_reader_ranks: List[int] = field(default_factory=list) + + buffer_handle: Optional[Tuple[int, int, int, str]] = None + local_subscribe_addr: Optional[str] = None + remote_subscribe_addr: Optional[str] = None + remote_addr_ipv6: bool = False + + +class MessageQueue: + + def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[List[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, + ): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, + max_chunks) + + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. 
otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) + local_subscribe_addr = get_open_zmq_ipc_path() + logger.debug("Binding to %s", local_subscribe_addr) + self.local_socket.bind(local_subscribe_addr) + + self.current_idx = 0 + else: + self.buffer = None # type: ignore + local_subscribe_addr = None + self.local_socket = None + self.current_idx = -1 + + remote_addr_ipv6 = False + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + if not connect_ip: + connect_ip = get_ip() + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) + remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + remote_addr_ipv6 = True + connect_ip = f"[{connect_ip}]" + socket_addr = f"tcp://*:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) + remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" + else: + remote_subscribe_addr = None + self.remote_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + + self.handle = Handle( + local_reader_ranks=local_reader_ranks, + buffer_handle=self.buffer.handle() + if self.buffer is not None else None, + local_subscribe_addr=local_subscribe_addr, + remote_subscribe_addr=remote_subscribe_addr, + remote_addr_ipv6=remote_addr_ipv6, + ) + + logger.info("vLLM message queue communication handle: %s", self.handle) + + def export_handle(self) -> Handle: + return self.handle + + @staticmethod + def create_from_handle(handle: Handle, rank) -> "MessageQueue": + self = MessageQueue.__new__(MessageQueue) + self.handle = handle + self._is_writer = False + + context = Context() + + if rank in handle.local_reader_ranks: + assert handle.buffer_handle is not None + self.buffer = ShmRingBuffer(*handle.buffer_handle) + self.current_idx = 0 + self.local_reader_rank = handle.local_reader_ranks.index(rank) + self._is_local_reader = True + self._is_remote_reader = False + + self.local_socket = context.socket(SUB) + self.local_socket.setsockopt_string(SUBSCRIBE, "") + socket_addr = handle.local_subscribe_addr + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) + + self.remote_socket = None + else: + self.buffer = None # type: ignore + self.current_idx = -1 + self.local_reader_rank = -1 + self._is_local_reader = False + self._is_remote_reader = True + + self.local_socket = None + + self.remote_socket = context.socket(SUB) + self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if handle.remote_addr_ipv6: + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = handle.remote_subscribe_addr + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) + + return self + + def wait_until_ready(self): + """This is a collective operation. All processes (including the + readers and the writer) should call this function. 
+ """ + if self._is_writer: + # wait for all readers to connect + + # local readers + for i in range(self.n_local_reader): + # wait for subscription messages from all local readers + self.local_socket.recv() + if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working + self.local_socket.send(b"READY") + + # remote readers + for i in range(self.n_remote_reader): + # wait for subscription messages from all remote readers + self.remote_socket.recv() + if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working + self.remote_socket.send(b"READY") + elif self._is_local_reader: + # wait for the writer to send a message + recv = self.local_socket.recv() + assert recv == b"READY" + elif self._is_remote_reader: + # wait for the writer to send a message + recv = self.remote_socket.recv() + assert recv == b"READY" + + @contextmanager + def acquire_write(self, timeout: Optional[float] = None): + assert self._is_writer, "Only writers can acquire write" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_count = sum(metadata_buffer[1:]) + written_flag = metadata_buffer[0] + if written_flag and read_count != self.buffer.n_reader: + # this block is written and not read by all readers + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # Release the processor to other threads + sched_yield() + + # if we wait for a long time, log a message + if (time.monotonic() - start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.debug( + ("No available shared memory broadcast block found" + " in %s second."), + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + @contextmanager + def acquire_read(self, timeout: Optional[float] = None): + assert self._is_local_reader, "Only readers can acquire read" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_flag = metadata_buffer[self.local_reader_rank + 1] + written_flag = metadata_buffer[0] + if not written_flag or read_flag: + # this block is either + # (1) not written + # (2) already read by this reader + + # for readers, `self.current_idx` is the next block to read + # if this block is not ready, + # we need to wait until it is written + + # Release the processor to other threads + sched_yield() + + # if we wait for a long time, log a message + if (time.monotonic() - start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * 
n_warning): + logger.debug( + ("No available shared memory broadcast block found" + "in %s second."), + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + + continue + # found a block that is not read by this reader + # let caller read from the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has read from the buffer + # set the read flag + metadata_buffer[self.local_reader_rank + 1] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + def enqueue(self, obj, timeout: Optional[float] = None): + """ Write to message queue with optional timeout (in seconds) """ + assert self._is_writer, "Only writers can enqueue" + serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + if self.n_local_reader > 0: + if len(serialized_obj) >= self.buffer.max_chunk_bytes: + with self.acquire_write(timeout) as buf: + buf[0] = 1 # overflow + self.local_socket.send(serialized_obj) + else: + with self.acquire_write(timeout) as buf: + buf[0] = 0 # not overflow + buf[1:len(serialized_obj) + 1] = serialized_obj + if self.n_remote_reader > 0: + self.remote_socket.send(serialized_obj) + + def dequeue(self, timeout: Optional[float] = None): + """ Read from message queue with optional timeout (in seconds) """ + if self._is_local_reader: + with self.acquire_read(timeout) as buf: + overflow = buf[0] == 1 + if not overflow: + # no need to know the size of serialized object + # pickle format contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf[1:]) + if overflow: + recv = self.local_socket.recv() + obj = pickle.loads(recv) + elif self._is_remote_reader: + recv = self.remote_socket.recv() + obj = pickle.loads(recv) + else: + raise RuntimeError("Only readers can dequeue") + return obj + + def broadcast_object(self, obj=None): + if self._is_writer: + self.enqueue(obj) + return obj + else: + return self.dequeue() + + @staticmethod + def create_from_process_group(pg: Union[ProcessGroup, + StatelessProcessGroup], + max_chunk_bytes, + max_chunks, + writer_rank=0) -> "MessageQueue": + if isinstance(pg, ProcessGroup): + group_rank = dist.get_rank(pg) + group_world_size = dist.get_world_size(pg) + global_ranks = dist.get_process_group_ranks(pg) + else: + group_rank = pg.rank + group_world_size = pg.world_size + global_ranks = list(range(pg.world_size)) + + from vllm.distributed.parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io: MessageQueue + if group_rank == writer_rank: + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + if isinstance(pg, ProcessGroup): + dist.broadcast_object_list([handle], + src=global_ranks[writer_rank], + group=pg) + else: + pg.broadcast_obj(handle, writer_rank) + else: + if isinstance(pg, ProcessGroup): + recv = [None] + dist.broadcast_object_list(recv, + src=global_ranks[writer_rank], + group=pg) + handle = recv[0] # type: ignore + else: + handle = pg.broadcast_obj(None, writer_rank) 
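+            # non-writer ranks attach to the writer's queue using the handle received above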
+ buffer_io = MessageQueue.create_from_handle(handle, group_rank) + buffer_io.wait_until_ready() + return buffer_io diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py new file mode 100644 index 0000000..de66cea --- /dev/null +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +USE_RAY = parallel_config = get_current_vllm_config( +).parallel_config.distributed_executor_backend == "ray" + +logger = init_logger(__name__) + +if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + + if USE_RAY: + from vllm.executor import ray_utils + + +class TpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node + # must be used together. Therefore, the local rank and world size can + # be simply calculated as follows. + global_rank = self.global_rank + global_world_size = self.global_world_size + + if USE_RAY: + logger.info("TpuCommunicator initialized with RAY") + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + + local_world_size = global_world_size // num_nodes + local_rank = global_rank % local_world_size + else: + logger.info("TpuCommunicator initialized with MP") + # Sanity: Verify we run on a single host + num_hosts = torch_xla.tpu.num_tpu_workers() + assert num_hosts == 1 + + # Get the current number of TPUs (we have locally) + local_world_size = torch_xla.tpu.num_available_chips() + + # Get current rank + local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + + pjrt.initialize_multiprocess(local_rank, local_world_size) + xr._init_world_size_ordinal() + self.groups = create_optimized_replica_groups() + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + # TODO: Remove the groups specification after XLA compiler can support + # auto-reordering the ring order for all-reduce. 
+ return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(input_, dim=dim) diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py new file mode 100644 index 0000000..256e796 --- /dev/null +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from .base_device_communicator import DeviceCommunicatorBase + + +class XpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + def all_reduce(self, input_) -> torch.Tensor: + dist.all_reduce(input_, group=self.device_group) + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # For xpu path, gather doesn't work properly together with ray + # cluster so we use all_gather instead for now. + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((self.world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + if self.rank_in_group == dst: + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + else: + output_tensor = None + return output_tensor diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md new file mode 100644 index 0000000..349d3df --- /dev/null +++ b/vllm/distributed/kv_transfer/README.md @@ -0,0 +1,29 @@ + +# Distributed KV cache transfer + +This folder implements distributed KV cache transfer across vLLM instances. +Currently the main usecase is for disaggregated prefilling. + +## Abstractions + +The KV cache transfer contains three layer of abstractions: + +- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`. +- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics). +- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`. + +Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. 
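+
+To make the lookup-buffer idea concrete, here is a deliberately simplified sketch of the
+intended `insert` / `drop_select` semantics (a plain in-memory dict keyed by the token
+sequence; the real buffer also carries hidden states and works across vLLM instances, so
+this illustrates the semantics only, not the actual interface):
+
+```python
+# Hypothetical, simplified sketch of the lookup-buffer semantics described above.
+buffer = {}
+
+def insert(tokens: tuple, kv) -> None:
+    # prefill worker: publish the KV cache under its token sequence
+    buffer[tokens] = kv
+
+def drop_select(tokens: tuple):
+    # decode worker: fetch (and drop) the KV cache for whichever request it is
+    # handling right now, independent of the order in which requests arrived
+    return buffer.pop(tokens, None)
+```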
+ +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or +RDMA database). + +NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. + +## Disaggregated prefilling + +The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh). + +Here is the diagram of how we run disaggregated prefilling. + +![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg new file mode 100644 index 0000000..a25ec5e Binary files /dev/null and b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg differ diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py new file mode 100644 index 0000000..57c764b --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +KVConnectorBase Class for Distributed KV Cache & Hidden State communication + +The class provides two primary abstract methods: +1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states +2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class KVConnectorBase(ABC): + """ + Abstract base class for a KV connector. + + The class provides two primary abstract methods: + 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states + 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + """ + Send KV caches and hidden states to the connector. 
+ + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. + kv_caches (List[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None + + """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (List[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. + + """ + + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 0000000..39c9809 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import TYPE_CHECKING, Callable, Dict, Type + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class KVConnectorFactory: + _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> Type[KVConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + return connector_cls(rank, local_rank, config) + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
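+# As a hypothetical example (names below are illustrative only), an out-of-tree
+# connector could be registered the same way, provided the module is importable:
+#   KVConnectorFactory.register_connector(
+#       "MyConnector", "my_package.my_connector", "MyConnector")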
+KVConnectorFactory.register_connector( + "P2pConnector", "vllm.distributed.kv_transfer.kv_connector.p2p_connector", + "P2pConnector") + +KVConnectorFactory.register_connector( + "PyNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnector", + "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", + "LMCacheConnector") + +KVConnectorFactory.register_connector( + "MooncakeStoreConnector", + "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", + "MooncakeStoreConnector") \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py new file mode 100644 index 0000000..42de227 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +LMCache KV Cache Connector for Distributed Machine Learning Inference + +The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker +(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; +(2) offload and share KV caches. +""" + +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class LMCacheConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.transfer_config = config.kv_transfer_config + self.vllm_config = config + + from lmcache.experimental.cache_engine import LMCacheEngineBuilder + from lmcache.integration.vllm.utils import ENGINE_NAME + from lmcache.integration.vllm.vllm_adapter import ( + RetrieveStatus, StoreStatus, init_lmcache_engine, + lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store, + lmcache_store_kv) + logger.info("Initializing LMCacheConfig under kv_transfer_config %s", + self.transfer_config) + + # TODO (Jiayi): Find model_config, parallel_config, and cache_config + self.engine = init_lmcache_engine(config.model_config, + config.parallel_config, + config.cache_config) + self.lmcache_engine_name = ENGINE_NAME + self.lmcache_engine_builder = LMCacheEngineBuilder + + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + self.lmcache_retrieve_kv = lmcache_retrieve_kv + self.lmcache_store_kv = lmcache_store_kv + self.lmcache_should_retrieve = lmcache_should_retrieve + self.lmcache_should_store = lmcache_should_store + self.store_status = StoreStatus + self.retrieve_status = RetrieveStatus + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + retrieve_status = self.lmcache_should_retrieve(model_input) + model_input, bypass_model_exec, hidden_or_intermediate_states =\ + self.lmcache_retrieve_kv( + model_executable, 
model_input, self.cache_config, kv_caches, + retrieve_status) + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + store_status = self.lmcache_should_store(model_input) + self.lmcache_store_kv( + self.model_config, + self.parallel_config, + self.cache_config, + model_executable, + model_input, + kv_caches, + store_status, + ) + + def close(self): + self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py new file mode 100644 index 0000000..c5135da --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MooncakeStore Connector for Distributed Machine Learning Inference + +The MooncakeStoreConnector transfers KV caches between prefill vLLM workers +(KV cache producer) and decode vLLM workers (KV cache consumer) using a +database-style KVStore. +""" +import hashlib +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class MooncakeStoreConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + + self.local_tp_rank = local_rank + + # Init kv_store + if self.config.kv_connector == "MooncakeStoreConnector": + # Check if MOONCAKE_CONFIG_PATH is set + import os + use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None + + if not use_mooncake_store: + raise ValueError( + "To use MooncakeStoreConnector, you need to pass the ENV: " + "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") + else: + from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501 + MooncakeStore) + logger.info( + "Initializing KVStoreConnector under kv_transfer_config %s", + self.config) + self.kv_store = MooncakeStore(config) + else: + logger.error("Can not find %s", self.config.kv_connector) + + assert self.kv_store is not None + + def close(self) -> None: + """Close the buffer and release resources. + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + Raises: + NotImplementedError: This method must be implemented in subclasses. 
+ """ + self.kv_store.close() + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + model_config = model_executable.model.config + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + current_tokens = input_tokens_tensor[start_pos:end_pos] + store_key_prefix = self.tensor_hash(current_tokens) + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + kvcache_to_sent = torch.stack((keys, values), dim=0) + store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}" + self.kv_store.put(store_kvcache_key, kvcache_to_sent) + + hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}" + self.kv_store.put(hidden_key, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + bypass_model_exec = True + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + hidden_or_intermediate_states_for_one_req = [] + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. 
+ logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + # get roi for current seq + load_key_prefix = self.tensor_hash(current_tokens) + load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}" + remote_kv = self.kv_store.get(load_kvcache_key) + hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}" + hidden = self.kv_store.get(hidden_key) + + if remote_kv is None or hidden is None: + # didn't find any match. + bypass_model_exec = False + continue + + num_computed_tokens = current_tokens.shape[0] + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # call self.kv_store to get kv layer by layer + for layer_id in range(start_layer, end_layer): + layer = model_executable.model.layers[layer_id] + # get kvcache object + kv_cache = kv_caches[layer_id - start_layer] + key_cache, value_cache = kv_cache[0], kv_cache[1] + # get remote kvcache + + remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ + layer_id] + # use ops.reshape_and_cache_flash to put kv into kvcache + ops.reshape_and_cache_flash( + remote_k.to(key_cache.device), + remote_v.to(value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def tensor_hash(tensor: torch.Tensor) -> int: + """Calculate the hash value of the tensor.""" + tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() + hash_object = hashlib.blake2b(tensor_bytes) + hash_hex = hash_object.hexdigest() + return int(hash_hex[:16], 16) diff --git a/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py b/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py new file mode 100644 index 0000000..f493fd9 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py @@ -0,0 +1,306 @@ +# Mainly adopted from https://github.com/FlagOpen/FlagScale/blob/44ceca57dd6f86b10163968e617497c613e47d6e/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py. 
+# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +import os +import re +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +if os.getenv("USE_FLAGCX", "false").lower() in ("1", "true"): + from vllm.distributed.kv_transfer.kv_pipe.flagcx_p2p_nccl_pipe import P2pNcclPipe +else: + from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class P2pConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + self.rank = rank + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + + assert self.config.kv_connector == "P2pConnector" + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.p2p_nccl_pipe = P2pNcclPipe( + local_rank=local_rank, + config=self.config, + hostname="", + port_offset=rank, + ) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + # input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + model_config = model_executable.model.config + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. + # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading + # to a kv_cache shape of [2, num_blks, blk_size, + # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. + # For more details, see vllm/attention/backends/mla/common.py. + if self.is_deepseek_mla and self.use_mla_opt: + head_size = model_config.kv_lora_rank + \ + model_config.qk_rope_head_dim + num_heads = 1 + elif self.is_deepseek_mla and not self.use_mla_opt: + head_size = model_config.qk_nope_head_dim + \ + model_config.qk_rope_head_dim + else: + head_size = getattr(model_config, "head_dim", + int(hidden_size // num_attention_heads)) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. 
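# Worked example of the head_size selection above (numbers are illustrative,
# roughly matching a DeepSeek-V3-style config, not read from this repo):
#   with MLA absorb (VLLM_MLA_DISABLE=0): head_size = kv_lora_rank (512)
#       + qk_rope_head_dim (64) = 576, and num_heads = 1
#   without absorb (VLLM_MLA_DISABLE=1):  head_size = qk_nope_head_dim (128)
#       + qk_rope_head_dim (64) = 192, num_heads = num_key_value_heads / tp
#   non-MLA models:                       head_size = hidden_size / num_attention_heads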
+ for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You have some decode requests while using " + "SimpleConnector. Their KVCache won't be sent.") + break + + # current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + ## 更改kv_cache 排布 + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + request_id = request_ids[idx] + ip, port = self.parse_request_id(request_id, True) + remote_address = ip + ":" + str(port + self.rank) + + self.p2p_nccl_pipe.send_tensor(request_id + "keys", keys, + remote_address) + self.p2p_nccl_pipe.send_tensor(request_id + "values", values, + remote_address) + self.p2p_nccl_pipe.send_tensor( + request_id + "hidden", + hidden_or_intermediate_states[start_pos:end_pos], + remote_address) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + model_config = model_executable.model.config + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. 
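# Wire format used by this connector: for every prefill request, three tensors
# are pushed to the decode peer, keyed by request_id plus a fixed suffix, and
# the decode side calls recv_tensor with the same keys. A minimal sketch
# (request_id and the remote address are placeholders):
#
#   pipe.send_tensor(request_id + "keys",   keys,          "10.0.0.2:21001")
#   pipe.send_tensor(request_id + "values", values,        "10.0.0.2:21001")
#   pipe.send_tensor(request_id + "hidden", hidden_states, "10.0.0.2:21001")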
+ logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to --max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + request_id = request_ids[idx] + ip, port = self.parse_request_id(request_id, False) + remote_address = ip + ":" + str(port + self.rank) + + keys = self.p2p_nccl_pipe.recv_tensor(request_id + "keys", + remote_address) + values = self.p2p_nccl_pipe.recv_tensor(request_id + "values", + remote_address) + hidden = self.p2p_nccl_pipe.recv_tensor(request_id + "hidden", + remote_address) + + num_computed_tokens = current_tokens.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), keys is not None, + values is not None, hidden is not None]): + bypass_model_exec = False + break + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys[ + i - model_executable.model.start_layer].to( + kv_cache.device).squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed, + k_pe, + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. 
+ logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def parse_request_id(request_id: str, is_prefill=True) -> Tuple[str, int]: + logger.debug("parse_request_id, request_id: %s, is_prefill: %s", + request_id, is_prefill) + # Regular expression to match the string hostname and integer port + if is_prefill: + pattern = r"___decode_addr_(.*):(\d+)" + else: + pattern = r"___prefill_addr_(.*):(\d+)___" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + ip = match.group(1) + port = int(match.group(2)) + + logger.debug("parse_request_id, request_id: %s, ip: %s, port: %s", + request_id, ip, str(port)) + return ip, port + raise ValueError( + f"Request id {request_id} does not contain hostname and port") + + def close(self): + self.p2p_nccl_pipe.close() \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py new file mode 100644 index 0000000..19e557e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or +MooncakePipe. + +But the logic can be extended to support other pipe and lookup buffer. 
+""" +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class SimpleConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + + if self.config.kv_connector == "PyNcclConnector": + from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( + PyNcclPipe) + logger.info( + "Initializing PyNcclConfig under kv_transfer_config %s", + self.config) + elif self.config.kv_connector == "MooncakeConnector": + # Check if MOONCAKE_CONFIG_PATH is set + import os + use_mooncake_distributed_pipe = os.getenv( + 'MOONCAKE_CONFIG_PATH') is not None + + if not use_mooncake_distributed_pipe: + raise ValueError( + "To use MooncakeConnector, you need to pass the ENV: " + "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") + else: + from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501 + MooncakePipe) + logger.info( + "Initializing MooncakeConfig under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[SimpleBuffer] = None + self.consumer_buffer: Optional[SimpleBuffer] = None + + self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe] + self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe] + self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + + # 2 pipes for every rank in the world + port_offset_base = 2 * rank + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + if self.config.is_kv_producer: + + if self.config.kv_connector == "PyNcclConnector": + self.producer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.producer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + elif self.config.kv_connector == "MooncakeConnector": + self.producer_data_pipe = MooncakePipe( + local_rank=local_rank, + config=self.config, + ) + # We only need to initialize MooncakePipe once + self.producer_signal_pipe = self.producer_data_pipe + + self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size) + + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + if self.config.kv_connector == "PyNcclConnector": + self.consumer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.consumer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + elif self.config.kv_connector == "MooncakeConnector": + 
self.consumer_data_pipe = MooncakePipe( + local_rank=local_rank, + config=self.config, + ) + self.consumer_signal_pipe = self.consumer_data_pipe + + self.consumer_buffer = SimpleBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, + self.config.kv_buffer_size, + ) + + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." + return self.consumer_buffer.drop_select(input_tokens, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + model_config = model_executable.model.config + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. + # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading + # to a kv_cache shape of [2, num_blks, blk_size, + # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. + # For more details, see vllm/attention/backends/mla/common.py. + if self.is_deepseek_mla and self.use_mla_opt: + head_size = model_config.kv_lora_rank + \ + model_config.qk_rope_head_dim + num_heads = 1 + elif self.is_deepseek_mla and not self.use_mla_opt: + head_size = model_config.qk_nope_head_dim + \ + model_config.qk_rope_head_dim + else: + head_size = getattr(model_config, "head_dim", + int(hidden_size // num_attention_heads)) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You have some decode requests while using " + "SimpleConnector. 
Their KVCache won't be sent.") + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + # TODO, fix this + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + model_config = model_executable.model.config + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to --max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. 
+ bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys[ + i - model_executable.model.start_layer].to( + kv_cache.device).squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed, + k_pe, + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.producer_data_pipe.close() + self.consumer_data_pipe.close() + if self.config.kv_connector == "PyNcclConnector": + self.producer_signal_pipe.close() + self.consumer_signal_pipe.close() + elif self.config.kv_connector == "MooncakeConnector": + # MooncakePipe reuses data_pipe for signal_pipe, so we only have to + # close the data_pipe. 
+ pass diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py new file mode 100644 index 0000000..bea4284 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +This file also contains a new class `KVStoreBufferBase` that allows developers +to manage the KVCache buffer as a simple key-value storage buffer with basic +put/get operations. + +These classes above are abstracted behind class `KVCacheBufferBase`. +""" + +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + + +class KVCacheBufferBase(ABC): + """ + Abstract base class for a KVCache buffer. + """ + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + KVCache buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + +class KVLookupBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. 
+ """ + raise NotImplementedError + + @abstractmethod + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + """Select and *drop* KV cache entries from the lookup buffer. + + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + List[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + +class KVStoreBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache storage buffer with key-value semantics. + This class provides a simple key-value storage buffer abstract with basic + put/get operations, which enables flexible KVCache transfer granular + control. + + The functionality is similar to a distributed key-value store, where: + - Key: A unique string identifier for the cached entry + - Value: + - Tensor to be stored and retrieved + - None (indicating deletion or empty value) + """ + + @abstractmethod + def put( + self, + key: str, + value: Optional[torch.Tensor], + ) -> None: + """Store a key-value pair in the buffer. + + Args: + key (str): Unique identifier for a tensor, this tensor could be the + key cache tensor, value cache tensor, or hidden state tensor + generated during model forwarding. + + value (Optional[torch.Tensor]): Tensor to be stored. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def get( + self, + key: str, + ) -> Optional[torch.Tensor]: + """Retrieve a value from the buffer by key. + + Args: + key (str): Unique identifier for a tensor, this tensor could be the + key cache tensor, value cache tensor, or hidden state tensor + generated during model forwarding. + + Returns: + Optional[torch.Tensor]: Stored tensor if exists, None otherwise. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py new file mode 100644 index 0000000..7fd5967 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains a new class `MooncakeStore` that allows developers to +think of KV cache transfer operations as putting new KV cache entries +into a remote KVStore-based lookup buffer and getting existing KV caches +from this remote lookup buffer. 
+""" +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVStoreBufferBase) +from vllm.logger import init_logger + +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB + +logger = init_logger(__name__) + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeStoreConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", + DEFAULT_GLOBAL_SEGMENT_SIZE), + local_buffer_size=config.get("local_buffer_size", + DEFAULT_LOCAL_BUFFER_SIZE), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address"), + ) + + @staticmethod + def load_from_env() -> 'MooncakeStoreConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_file_path) + + +class MooncakeStore(KVStoreBufferBase): + + def __init__( + self, + config: VllmConfig, + ): + + try: + from mooncake_vllm_adaptor import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + + try: + self.store = MooncakeDistributedStore() + self.config = MooncakeStoreConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + + self.store.setup(self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address) + + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + + def close(self): + # MooncakeDistributedStore will automatically call the destructor, so + # it is unnecessary to close it manually. + pass + + def put( + self, + key: str, + value: Optional[torch.Tensor], + ) -> None: + # A message queue needs to be introduced before making it asynchronous. + if value is not None: + self._put_impl(key, value) + + def get( + self, + key: str, + ) -> Optional[torch.Tensor]: + # A message queue needs to be introduced before making it asynchronous. 
+ value = self._get_impl(key) + return value + + def _put_impl( + self, + key: str, + value: torch.Tensor, + ) -> None: + """Put KVCache to Mooncake Store""" + device_id = value.device.index if value.device.type == 'cuda' else -1 + device_tensor = torch.tensor(device_id, dtype=torch.int32) + value_bytes = safetensors_save({ + "tensor": value, + "device_id": device_tensor + }) + try: + self.store.put(key, value_bytes) + except TypeError as err: + logger.error("Failed to put value into Mooncake Store: %s", err) + raise TypeError("Mooncake Store Put Type Error.") from err + + def _get_impl( + self, + key: str, + ) -> Optional[torch.Tensor]: + """Get KVCache from Mooncake Store""" + try: + data = self.store.get(key) + except TypeError as err: + logger.error("Failed to get value from Mooncake Store: %s", err) + raise TypeError("Mooncake Store Get Type Error.") from err + + if data: + loaded_tensors = safetensors_load(data) + tensor = loaded_tensors["tensor"] + device_id_tensor = loaded_tensors["device_id"] + device_id = int(device_id_tensor.item()) + device = torch.device( + 'cuda', device_id) if device_id >= 0 else torch.device('cpu') + return tensor.to(device) + + return None diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py new file mode 100644 index 0000000..10bbfe1 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +""" + Implements a distributed key-value (KV) cache transfer mechanism. + + Key Features: + - Distributed KV cache transmission using PyNccl pipes. + - Non-blocking `insert`, blocking `drop_select`. + - Use CPU signal pipe to avoid racing condition + - Handles buffer size constraints and provide backpressure mechanism to + stop the prefill instance when the decode instance is slow. +""" +import threading +from collections import deque +from typing import Deque, List, Optional, Union + +import torch + +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SimpleBuffer(KVLookupBufferBase): + + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, + buffer_size_thresh: float): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. 
GPU) + """ + + self.buffer: Deque[List[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_cv = threading.Condition() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0], device="cpu") + self.end_signal = None + + def _matches(self, tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + if tensor.dtype == torch.bool: + tensor = tensor.float() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError(f"Unknown data type {type(data)}") + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + data_size = sum([self._get_element_size(data) for data in buffer_item]) + + with self.buffer_cv: + if self.buffer_size + data_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. 
Handling...") + while self.buffer_size + data_size > self.buffer_size_threshold: + self.buffer_cv.wait() + + self.buffer_size += data_size + self.buffer.append(buffer_item) + self.buffer_cv.notify() + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [input_tokens, roi] + + def is_buffer_available( + tokens_roi_recver: List[torch.Tensor], ) -> bool: + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + for _ in range(len(self.buffer)): + if self._matches(self.buffer[0], + tokens_roi_recver) > 0: + return True + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + return False + + with self.buffer_cv: + while not is_buffer_available(tokens_roi_recver): + logger.debug( + "KV transfer buffer is not available. Waiting...") + self.buffer_cv.wait() + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + self.buffer_cv.notify() + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone().float() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + if roi is not None: + # convert from float tensor to bool tensor + # as PyNccl does not support sending bool tensor + roi = (roi > 0.5) + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. 
+ if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py new file mode 100644 index 0000000..40589fb --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file defines an interface `KVPipeBase` +that provides an abstraction for sending and receiving tensors, or None, via +distributed communications. + +All classes instantiated from this interface are assumed to be a FIFO pipe. + +If your distributed communication platform already supports key-value lookup, +you can bypass this interface and directly start from `kv_lookup_buffer`. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVPipeBase(ABC): + """ + This class provides an interface for sending and receiving tensors, or + None, by distributed communications. + """ + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send a tensor, or None, via the pipe. + + Need to support sending None -- important for error handling. + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind + the pipe. + + Args: + tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive a tensor (can be None) from the pipeline. + + Returns: + Optional[torch.Tensor]: The tensor received from the pipeline. Can + be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the pipeline and release resources. + + This method is responsible for closing the communication pipeline + and releasing any resources associated with it. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py new file mode 100644 index 0000000..3ceec79 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py @@ -0,0 +1,458 @@ +# Mainly adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py. 
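# This FlagCX-backed pipe is only imported when USE_FLAGCX is truthy (see
# p2p_connector.py), and it resolves both the Python wrapper and the shared
# library relative to FLAGCX_PATH. A minimal environment sketch, assuming a
# FlagCX checkout built in-tree at a placeholder location:
#
#   export FLAGCX_PATH=/opt/FlagCX   # provides plugin/interservice/flagcx_wrapper.py
#                                    # and build/lib/libflagcx.so
#   export USE_FLAGCX=true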
+# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +import os +import logging +import threading +import time +import typing +from collections import deque +from typing import Any, Deque, Dict, List, Optional + +import msgpack +import torch +import zmq +import ctypes +import sys +sys.path.append(os.getenv('FLAGCX_PATH')) +from plugin.interservice.flagcx_wrapper import ( + FLAGCXLibrary, + buffer_type, + cudaStream_t, + flagcxComm_t, + flagcxDataTypeEnum, +) +from vllm.config import KVTransferConfig +from vllm.utils import current_stream, get_ip + +logger = logging.getLogger(__name__) + + +class P2pNcclPipe: + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None) -> None: + self.config = config + self.rank = port_offset + self.local_rank = local_rank + self.device = torch.device(f"cuda:{self.local_rank}") + flagcx_path = os.getenv('FLAGCX_PATH') + library_path=os.path.join(flagcx_path, "build/lib/libflagcx.so") + self.flagcx = FLAGCXLibrary(library_path) + + if not hostname: + hostname = get_ip() + port = self.config.kv_port + port_offset + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + + # Each card corresponds to a ZMQ address. + self.zmq_address = f"{self._hostname}:{self._port}" + + # The `http_port` must be consistent with the port of OpenAI. + self.http_address = ( + f"{self._hostname}:" + f"{self.config.kv_connector_extra_config['http_port']}") + + # If `proxy_ip` or `proxy_port` is `""`, + # then the ping thread will not be enabled. + proxy_ip = self.config.get_from_extra_config("proxy_ip", "") + proxy_port = self.config.get_from_extra_config("proxy_port", "") + if proxy_ip == "" or proxy_port == "": + self.proxy_address = "" + else: + self.proxy_address = proxy_ip + ":" + proxy_port + + self.context = zmq.Context() + self.router_socket = self.context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://{self.zmq_address}") + + self.poller = zmq.Poller() + self.poller.register(self.router_socket, zmq.POLLIN) + + self.send_store_cv = threading.Condition() + self.send_queue_cv = threading.Condition() + self.recv_store_cv = threading.Condition() + self.comm_cv = threading.Condition() + + # The sending type includes tree mutually exclusive options: + # PUT, GET, PUT_ASYNC. 
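# Summary of the three modes (configured via kv_connector_extra_config
# {"send_type": ...}, read just below):
#   PUT       - send_tensor() pushes synchronously to the peer over NCCL.
#   PUT_ASYNC - send_tensor() only enqueues; a daemon thread (_send_async)
#               drains send_queue and performs the actual sends.
#   GET       - the sender parks tensors in send_store (LRU-evicted against
#               kv_buffer_size); the receiver requests them by tensor_id via a
#               ZMQ "GET" message and then receives the data over NCCL.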
+ self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + self.send_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + else: + # PUT or PUT_ASYNC + self.send_queue: Deque[ + List[Any]] = deque() # tensor_id: torch.Tensor + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + self.recv_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + self.socks: Dict[str, Any] = {} # remote_address: client socket + self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = self.config.kv_buffer_size + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + unique_id = self.flagcx.flagcxGetUniqueId().contents + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.info( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + 
remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + tensor = self.recv_store[tensor_id] + self.recv_store[tensor_id] = None + while len(self.recv_store) > 10000: + self.recv_store.pop(next(iter(self.recv_store))) + + duration = time.time() - start_time + if tensor is not None: + self.buffer_size -= (tensor.element_size() * tensor.numel()) + logger.info( + "🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, " + "duration:%.3fms, size:%.3fGB, rank:%d", remote_address, + tensor_id, tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, + self.rank) + else: + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + start_time = time.time() + self._recv(comm, tensor, rank ^ 1) + duration = time.time() - start_time + logger.info( + "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " + "size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape, + duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, self.rank) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + logger.debug("Received message from %s, data:%s", + remote_address.decode(), data) + if data["cmd"] == "NEW": + unique_id = self.flagcx.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + # comm: ncclComm_t = self.nccl.ncclCommInitRank( + # 2, unique_id, rank) + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) + + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + self.router_socket.send_multipart( + [remote_address, b"2"]) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s", self.zmq_address, + remote_address.decode(), data) + tensor = None + else: + self.buffer_size += tensor_size + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1) + logger.info( + "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, " + "data:%s, shape:%s", self.zmq_address, + remote_address.decode(), rank, data, + tensor.shape) + + except 
torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + tensor = None + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + self._send(comm, tensor.to(self.device), rank ^ 1) + + logger.info( + "🔵[GET]Send Tensor, %s👉%s, " + "MyRank:%s, data:%s", self.zmq_address, + remote_address.decode(), rank, data) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + # Asynchronous sending may cause conflicts between P2P NCCL and + # NCCL used in TP/PP, which can lead to deadlock issues. + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() + self._send_sync(tensor_id, tensor, remote_address) + + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.info( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + # with self.send_queue_cv: + # self.send_queue.append([tensor_id, remote_address, tensor]) + # self.send_queue_cv.notify() + logger.warning( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1) + logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s", + self.zmq_address, remote_address, rank, data, tensor.shape) + return True + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + sock.send(msgpack.dumps(data)) + time.sleep(3) + + def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator 
is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with self.comm_cv: + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), dst, + comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with self.comm_cv: + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), src, + comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) + + def close(self) -> None: + self._listener_thread.join() + if self.send_type == "PUT_ASYNC": + self._send_thread.join() + if self._ping_thread is not None: + self._ping_thread.join() \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py new file mode 100644 index 0000000..ec46d40 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import zmq +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) +NONE_INT = -150886311 + + +@dataclass +class MooncakeTransferEngineConfig: + prefill_url: str + decode_url: str + metadata_backend: Union[str, None] + metadata_server: str + protocol: str + device_name: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeTransferEngineConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeTransferEngineConfig( + prefill_url=config.get("prefill_url"), + decode_url=config.get("decode_url"), + metadata_backend=config.get("metadata_backend", None), + metadata_server=config.get("metadata_server"), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + ) + + @staticmethod + def load_from_env() -> 'MooncakeTransferEngineConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeTransferEngineConfig.from_file(config_file_path) + + +class MooncakeTransferEngine: + """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ.""" + + def __init__(self, kv_rank: int, local_rank: int): + try: + import mooncake_vllm_adaptor as mva + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + + self.engine = mva.mooncake_vllm_adaptor() + 
self.local_rank = local_rank + + try: + self.config = MooncakeTransferEngineConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + except ValueError as e: + logger.error(e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + prefill_host, base_prefill_port = self.config.prefill_url.split(':') + decode_host, base_decode_port = self.config.decode_url.split(':') + + # Avoid ports conflict when running prefill and decode on the same node + if prefill_host == decode_host and \ + base_prefill_port == base_decode_port: + base_decode_port = str(int(base_decode_port) + 100) + + prefill_port = int(base_prefill_port) + self.local_rank + decode_port = int(base_decode_port) + self.local_rank + self.prefill_url = ':'.join([prefill_host, str(prefill_port)]) + self.decode_url = ':'.join([decode_host, str(decode_port)]) + + self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url, + self.config.metadata_server, self.config.protocol, + self.config.device_name, self.config.metadata_backend) + + self.remote_url = (self.decode_url + if kv_rank == 0 else self.prefill_url) + + # Initialize ZeroMQ context and sockets + self.context = zmq.Context() # type: ignore[attr-defined] + self.sender_socket = self.context.socket(zmq.constants.PUSH) + self.receiver_socket = self.context.socket(zmq.constants.PULL) + self.sender_ack = self.context.socket(zmq.constants.PULL) + self.receiver_ack = self.context.socket(zmq.constants.PUSH) + + self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) + self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port, + decode_host, base_decode_port) + + def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str, + d_host: str, d_port: str) -> None: + """Set up ZeroMQ sockets for sending and receiving data.""" + # Offsets < 8 are left for initialization in case tp and pp are enabled + p_rank_offset = int(p_port) + 8 + self.local_rank * 2 + d_rank_offset = int(d_port) + 8 + self.local_rank * 2 + if kv_rank == 0: + self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}") + self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") + self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") + self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}") + else: + self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") + self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}") + self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}") + self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") + + def initialize(self, local_hostname: str, metadata_server: str, + protocol: str, device_name: str, + metadata_backend: Union[str, None]) -> None: + """Initialize the mooncake instance.""" + if metadata_backend is None: + self.engine.initialize(local_hostname, metadata_server, protocol, + device_name) + else: + supported_backend = ["etcd", "redis"] + metadata_backend = metadata_backend.lower() + if metadata_backend not in supported_backend: + raise ValueError( + "Mooncake Configuration error. 
`metadata_backend`" + f" should be one of {supported_backend}.") + + self.engine.initializeExt(local_hostname, metadata_server, + protocol, device_name, metadata_backend) + + def allocate_managed_buffer(self, length: int) -> int: + """Allocate a managed buffer of the specified length.""" + ret = self.engine.allocateManagedBuffer(length) + if ret <= 0: + logger.error("Allocation Return Error") + raise Exception("Allocation Return Error") + return ret + + def free_managed_buffer(self, buffer: int, length: int) -> int: + """Free a previously allocated managed buffer.""" + return self.engine.freeManagedBuffer(buffer, length) + + def transfer_sync(self, buffer: int, peer_buffer_address: int, + length: int) -> int: + """Synchronously transfer data to the specified address.""" + ret = self.engine.transferSync(self.remote_url, buffer, + peer_buffer_address, length) + if ret < 0: + logger.error("Transfer Return Error") + raise Exception("Transfer Return Error") + return ret + + def write_bytes_to_buffer(self, buffer: int, user_data: bytes, + length: int) -> int: + """Write bytes to the allocated buffer.""" + return self.engine.writeBytesToBuffer(buffer, user_data, length) + + def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: + """Read bytes from the allocated buffer.""" + return self.engine.readBytesFromBuffer(buffer, length) + + def wait_for_ack(self, src_ptr: int, length: int) -> None: + """Asynchronously wait for ACK from the receiver.""" + ack = self.sender_ack.recv_pyobj() + if ack != b'ACK': + logger.error("Failed to receive ACK from the receiver") + + self.free_managed_buffer(src_ptr, length) + + def send_bytes(self, user_data: bytes) -> None: + """Send bytes to the remote process.""" + length = len(user_data) + src_ptr = self.allocate_managed_buffer(length) + self.write_bytes_to_buffer(src_ptr, user_data, length) + self.sender_socket.send_pyobj((src_ptr, length)) + self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) + + def recv_bytes(self) -> bytes: + """Receive bytes from the remote process.""" + src_ptr, length = self.receiver_socket.recv_pyobj() + dst_ptr = self.allocate_managed_buffer(length) + self.transfer_sync(dst_ptr, src_ptr, length) + ret = self.read_bytes_from_buffer(dst_ptr, length) + + # Buffer cleanup + self.receiver_ack.send_pyobj(b'ACK') + self.free_managed_buffer(dst_ptr, length) + + return ret + + +class MooncakePipe(KVPipeBase): + """MooncakeTransferEngine based Pipe implementation.""" + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None): + """Initialize the mooncake pipe and set related parameters.""" + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + self.transfer_engine = MooncakeTransferEngine(self.kv_rank, + self.local_rank) + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + def _select_device(self, device: str) -> torch.device: + """Select available device (CUDA or CPU).""" + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def tensor_hash(self, tensor: torch.Tensor) -> int: + """Calculate the hash value of the tensor.""" + return hash(tensor.data_ptr()) + + def _send_impl(self, tensor: torch.Tensor) -> None: + """Implement 
the tensor sending logic using safetensors.""" + self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) + + def _recv_impl(self) -> torch.Tensor: + """Implement the tensor receiving logic using safetensors.""" + data = self.transfer_engine.recv_bytes() + return safetensors_load(data)["tensor"].to(self.device) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send tensor to the target process.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = tensor if tensor is not None else self.none_tensor + assert (len(tensor.shape) > 0) + self.transport_thread.submit(self._send_impl, tensor) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive tensor from other processes.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = self.transport_thread.submit(self._recv_impl).result() + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self) -> None: + """Cleanup logic when closing the pipe.""" + self.transfer_engine.sender_socket.close() + self.transfer_engine.receiver_socket.close() + self.transfer_engine.sender_ack.close() + self.transfer_engine.receiver_ack.close() + self.transfer_engine.context.term() # Terminate the ZMQ context + logger.info("Closed the transfer engine and cleaned up resources.") diff --git a/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py new file mode 100644 index 0000000..2451cd3 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py @@ -0,0 +1,445 @@ +# Copied adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py. +# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +import logging +import threading +import time +import typing +from collections import deque +from typing import Any, Deque, Dict, List, Optional + +import msgpack +import torch +import zmq + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) +from vllm.utils import current_stream, get_ip + +logger = logging.getLogger(__name__) + + +class P2pNcclPipe: + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None) -> None: + self.config = config + self.rank = port_offset + self.local_rank = local_rank + self.device = torch.device(f"cuda:{self.local_rank}") + self.nccl = NCCLLibrary(library_path) + + if not hostname: + hostname = get_ip() + port = self.config.kv_port + port_offset + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + + # Each card corresponds to a ZMQ address. + self.zmq_address = f"{self._hostname}:{self._port}" + + # The `http_port` must be consistent with the port of OpenAI. + self.http_address = ( + f"{self._hostname}:" + f"{self.config.kv_connector_extra_config['http_port']}") + + # If `proxy_ip` or `proxy_port` is `""`, + # then the ping thread will not be enabled. 
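+        # When enabled, the ping thread (_ping below) re-registers this pipe
+        # with the proxy every 3 seconds via a msgpack message of the form
+        # (values illustrative):
+        #   {"type": "P",   # "P" if this instance is a KV producer, else "D"
+        #    "http_address": "10.0.0.2:8000",
+        #    "zmq_address": "10.0.0.2:21001"}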
+ proxy_ip = self.config.get_from_extra_config("proxy_ip", "") + proxy_port = self.config.get_from_extra_config("proxy_port", "") + if proxy_ip == "" or proxy_port == "": + self.proxy_address = "" + else: + self.proxy_address = proxy_ip + ":" + proxy_port + + self.context = zmq.Context() + self.router_socket = self.context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://{self.zmq_address}") + + self.poller = zmq.Poller() + self.poller.register(self.router_socket, zmq.POLLIN) + + self.send_store_cv = threading.Condition() + self.send_queue_cv = threading.Condition() + self.recv_store_cv = threading.Condition() + + self.send_stream = torch.cuda.Stream() + self.recv_stream = torch.cuda.Stream() + + # The sending type includes tree mutually exclusive options: + # PUT, GET, PUT_ASYNC. + self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + self.send_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + else: + # PUT or PUT_ASYNC + self.send_queue: Deque[ + List[Any]] = deque() # tensor_id: torch.Tensor + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + self.recv_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + self.socks: Dict[str, Any] = {} # remote_address: client socket + self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = self.config.kv_buffer_size + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + unique_id = self.nccl.ncclGetUniqueId() + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + 
oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.info( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + tensor = self.recv_store[tensor_id] + self.recv_store[tensor_id] = None + while len(self.recv_store) > 10000: + self.recv_store.pop(next(iter(self.recv_store))) + + duration = time.time() - start_time + if tensor is not None: + self.buffer_size -= (tensor.element_size() * tensor.numel()) + logger.info( + "🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, " + "duration:%.3fms, size:%.3fGB, rank:%d", remote_address, + tensor_id, tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, + self.rank) + else: + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + start_time = time.time() + self._recv(comm, tensor, rank ^ 1, self.recv_stream) + duration = time.time() - start_time + logger.info( + "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " + "size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape, + duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, self.rank) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + logger.debug("Received message from %s, data:%s", + remote_address.decode(), data) + if data["cmd"] == "NEW": + unique_id = self.nccl.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) 
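+                        # One-byte ZMQ acks precede the NCCL receive below:
+                        # b"0" = ready to receive, b"2" = buffer threshold
+                        # exceeded, b"1" = CUDA out of memory (sent from the
+                        # except branch). _send_sync() on the peer only starts
+                        # its NCCL send after seeing b"0".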
+ + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + self.router_socket.send_multipart( + [remote_address, b"2"]) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s", self.zmq_address, + remote_address.decode(), data) + tensor = None + else: + self.buffer_size += tensor_size + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1, + self.recv_stream) + logger.info( + "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, " + "data:%s, shape:%s", self.zmq_address, + remote_address.decode(), rank, data, + tensor.shape) + + except torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + tensor = None + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + comm, rank = self.comms[remote_address.decode()] + self._send(comm, tensor.to(self.device), rank ^ 1, + self.send_stream) + + logger.info( + "🔵[GET]Send Tensor, %s👉%s, " + "MyRank:%s, data:%s", self.zmq_address, + remote_address.decode(), rank, data) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() + self._send_sync(tensor_id, tensor, remote_address) + + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.info( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + # with self.send_queue_cv: + # self.send_queue.append([tensor_id, remote_address, tensor]) + # self.send_queue_cv.notify() + logger.warning( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) + logger.info("🔵Send Tensor, %s👉%s, 
MyRank:%s, data:%s, tensor:%s", + self.zmq_address, remote_address, rank, data, tensor.shape) + return True + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + sock.send(msgpack.dumps(data)) + time.sleep(3) + + def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with torch.cuda.stream(stream): + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + comm, cudaStream_t(stream.cuda_stream)) + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with torch.cuda.stream(stream): + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + comm, cudaStream_t(stream.cuda_stream)) + + def close(self) -> None: + self._listener_thread.join() + if self.send_type == "PUT_ASYNC": + self._send_thread.join() + if self._ping_thread is not None: + self._ping_thread.join() \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py new file mode 100644 index 0000000..e8bf607 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -0,0 +1,279 @@ +# SPDX-License-Identifier: Apache-2.0 +""" + This module implements a PyNccl pipe for sending and receiving + Optional[torch.Tensor] between distributed ranks with advanced + communication features. 
+ + Key Features: + - Supports sending and receiving tensors with metadata + - Handles both CUDA and CPU device communications + - Implements a non-blocking tensor transfer mechanism + - Manages buffer size and provides backpressure control + - Supports distributed process groups with configurable parameters +""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Dict, Optional, Tuple + +import torch + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +Metadata = Dict[str, Optional[torch.Tensor]] + + +class PyNcclPipe(KVPipeBase): + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0): + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + self.kv_parallel_size = self.config.kv_parallel_size + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + # build distributed connection and send/recv implementation + store_timeout = self.config.get_from_extra_config("store_timeout", 300) + self.group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port + port_offset, + rank=self.kv_rank, + world_size=self.kv_parallel_size, + store_timeout=store_timeout, + ) + # add a barrier to make sure the connection is initiated properly + self.group.barrier() + impl = self._get_device_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = impl + # set target rank + self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size + self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size + + # transportation-related variables + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + self.buffer_size_thresh = self.config.kv_buffer_size + + def _get_device_send_recv_impl( + self, group: StatelessProcessGroup + ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[ + [torch.Tensor, int], None]]: + + send: Callable[[torch.Tensor, int], None] + recv: Callable[[torch.Tensor, int], None] + if self.device.type == "cuda": + # use PyNCCL for send / recv + comm = PyNcclCommunicator(group, device=self.local_rank) + comm.disabled = False + send, recv = comm.send, comm.recv # type: ignore + else: + # This send / recv implementation here is NOT intended to transfer + # KV caches (and should NOT be repurposed to transfer KV caches). + # Currently it is only used to transmit control-plane messages + # for PyNcclBuffer. + send = group.send_obj + + def my_recv(x, src): + x[...] 
= group.recv_obj(src) + + recv = my_recv + + return send, recv + + def _select_device(self, device: str): + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: + """ + Create the metadata as a dictionary based on the input tensor. + + Parameters: + - tensor: The input tensor or None if no tensor is provided. + + Returns: + - metadata: A dictionary with the following keys: + - "dtype": The data type of the tensor or None. + - "shape": The shape of the tensor or None. + """ + if tensor is None: + return {"dtype": None, "shape": None} + else: + return {"dtype": tensor.dtype, "shape": tensor.shape} + + def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the provided metadata. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape", describing + the tensor's data type and shape. + + Returns: + - buffer: A tensor of the specified type and shape, allocated on + self.device. + """ + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], + device=self.device) + + def _send_metadata(self, metadata: Metadata): + """ + Send the metadata dictionary to the target rank. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape". + """ + self.group.send_obj(metadata, self.target_rank_for_send) + + def _recv_metadata(self) -> Metadata: + """ + Receive the metadata dictionary from the target rank. + + Returns: + - metadata: A dictionary with keys "dtype" and "shape" describing + the tensor. + """ + return self.group.recv_obj(self.target_rank_for_recv) + + def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + """ + The actual implementation of sending the tensor and its metadata to the + target rank. + + Parameters: + - tensor: The input tensor to be sent, or None if no tensor is + being sent. + """ + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + if tensor is not None: + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) + + def _recv_impl(self) -> Optional[torch.Tensor]: + """ + The actual implementation of receiving a tensor and its metadata from + the target rank. + + Returns: + - buffer: The received tensor, or None if no tensor is received. + """ + metadata = self._recv_metadata() + if metadata["dtype"] is None: + return None + buffer = self._prepare_recv_buffer(metadata) + self.device_recv_func(buffer, self.target_rank_for_recv) + + return buffer + + def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], + tensor_size: int) -> None: + """ + Wrapper for _send_impl to handle exceptions and update buffer size. + """ + try: + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size -= tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """ + Block the current thread if the buffer size is larger than the + threshold. + """ + while self.buffer_size > self.buffer_size_thresh: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """ + Sends a tensor and its metadata to the destination rank in a + non-blocking way. 
+ + Parameters: + - tensor: The tensor to send, or None if no tensor is being sent. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is not None: + tensor_size = tensor.element_size() * tensor.numel() + else: + tensor_size = 0 + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size += tensor_size + + self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """ + Receives a tensor and its metadata from the source rank. Blocking call. + + Returns: + - tensor: The received tensor, or None if no tensor is received. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + logger.error("My device: %s", self.device) + import traceback + traceback.print_exc() + raise e + + return tensor + + def close(self): + """ + Close the pipe and release associated resources. + """ + if hasattr(self, + "transport_thread") and self.transport_thread is not None: + self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py new file mode 100644 index 0000000..1e80e0b --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A centralized entrypoint to perform distributed KV cache transfer. + +This implementation is a shim wrapper on two APIs exposed by `kv_connector`: +1. `send_kv_caches_and_hidden_states` +2. `recv_kv_caches_and_hidden_states +""" +from typing import TYPE_CHECKING, List, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.config import VllmConfig + +import torch + +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class KVTransferAgent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + + self.config = config + + if config.kv_transfer_config is None: + raise ValueError("KVTransferConfig is not set in the VllmConfig," + " cannot initialize KVConnector.") + + assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ + "TransferAgent should only be used when kv_connector is set." 
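+        # A minimal sketch of how kv_transfer_config is typically provided for
+        # disaggregated prefill (illustrative; these launch flags are not part
+        # of this patch):
+        #   --kv-transfer-config '{"kv_connector": "PyNcclConnector",
+        #                          "kv_role": "kv_producer",
+        #                          "kv_rank": 0, "kv_parallel_size": 2}'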
+ + self.connector = KVConnectorFactory.create_connector( + rank, local_rank, config) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + self.connector.send_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches, + hidden_or_intermediate_states) + + def close(self) -> None: + self.connector.close() + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + return self.connector.recv_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py new file mode 100644 index 0000000..e972b79 --- /dev/null +++ b/vllm/distributed/parallel_state.py @@ -0,0 +1,1228 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. +""" +import contextlib +import gc +import pickle +import weakref +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer +import vllm.envs as envs +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger +from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, + supports_custom_op) +import ixformer.distributed as ixfd +import vllm._custom_ops as ops + + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. 
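+
+    For example (hypothetical input), {"x": <cuda tensor>, "step": 3} becomes
+    ([("x", TensorMetadata("cuda", x.dtype, x.size())), ("step", 3)],
+     [<cuda tensor>]).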
+ """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list: List[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + _groups[group.unique_name] = weakref.ref(group) + + +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) + + +def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + +if supports_custom_op(): + from vllm.platforms import current_platform + direct_register_custom_op( + op_name="all_reduce", + op_func=all_reduce, + mutates_args=[], + fake_impl=all_reduce_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It manages both CPU and device + communication. 
+ """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_device_communicator: bool # whether to use device communicator + device_communicator: DeviceCommunicatorBase # device communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + ixfd.init_comm_with_store(device_group) + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + from vllm.platforms import current_platform + + # TODO: fix it for other platforms + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_device_communicator = use_device_communicator + + self.device_communicator: DeviceCommunicatorBase = None # type: ignore + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls()) + self.device_communicator = device_comm_cls( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + + from vllm.distributed.device_communicators.shm_broadcast import ( + MessageQueue) + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + from vllm.platforms import current_platform + self.use_custom_op_call = False + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the 
process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + # only cuda uses this function, + # so we don't abstract it into the base class + maybe_ca_context = nullcontext() + from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase + if self.device_communicator is not None: + assert isinstance(self.device_communicator, DeviceCommunicatorBase) + ca_comm = self.device_communicator.ca_comm + if ca_comm is not None: + maybe_ca_context = ca_comm.capture() # type: ignore + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we always make the all-reduce operation + out-of-place. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if self.use_custom_op_call: + return torch.ops.vllm.all_reduce(input_, + group_name=self.unique_name) + else: + return self._all_reduce_out_place(input_) + + def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + return self.device_communicator.all_reduce(input_) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + return self.device_communicator.all_gather(input_, dim) + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + return self.device_communicator.gather(input_, dst, dim) + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. 
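+        # On multi-rank groups the broadcast below goes through
+        # vllm._custom_ops (ops.broadcast) when the device communicator sets
+        # use_vllm_comm, and through torch.distributed otherwise.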
+ if self.world_size == 1: + return input_ + # Broadcast. + if self.device_communicator.use_vllm_comm: + ops.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + else: + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], + src=self.ranks[src], + group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=self.ranks[src], + group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, + obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank.") + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank_in_group, ( + "Invalid source rank. Source rank is the same as the current rank." + ) + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, + src=self.ranks[src], + group=self.cpu_group) + + # Tensor to receive serialized objects into. 
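+        # Its length was received above as size_tensor; the payload is the
+        # pickled object produced by send_object() on the peer rank.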
+ object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu") + + rank_object = torch.distributed.recv(object_tensor, + src=self.ranks[src], + group=self.cpu_group) + + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + if self.device_communicator.use_vllm_comm: + handle = ops.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + else: + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + if self.device_communicator.use_vllm_comm: + handle = ops.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + else: + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. 
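+        If `dst` is None, the tensors are sent to the next rank in the group,
+        i.e. `(rank_in_group + 1) % world_size`.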
+ """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. + if (all_gather_group is not None + and tensor.numel() % all_gather_size == 0): + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) + else: + # use group for GPU tensors + if self.device_communicator.use_vllm_comm: + ixfd.send(tensor, + dst=self.ranks[dst], + group=group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. 
+ use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, + -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=metadata_group) + else: + # use group for GPU tensors + if self.device_communicator.use_vllm_comm: + ixfd.recv(tensor, + src=self.ranks[src], + group=group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather( # type: ignore + tensor, dim=0) + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + self.device_communicator.send(tensor, dst) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + return self.device_communicator.recv(size, dtype, src) + + def destroy(self): + if self.device_group is not None: + if self.device_communicator.use_vllm_comm: + ixfd.destroy_process_group(self.device_group) + else: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.device_communicator is not None: + self.device_communicator.destroy() + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + +_WORLD: Optional[GroupCoordinator] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, ("world group is not initialized") + return _WORLD + + +def init_world_group(ranks: List[int], local_rank: int, + backend: str) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, +) -> GroupCoordinator: + + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + + +_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, ("tensor model parallel group is not initialized") + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + +_DP: Optional[GroupCoordinator] = None + + +def get_dp_group() -> GroupCoordinator: + assert _DP is not None, ("data 
parallel group is not initialized") + return _DP + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, ( + "pipeline model parallel group is not initialized") + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + +_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None + + +def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: + assert _KV_TRANSFER is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_TRANSFER + + +@contextmanager +def graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + context = GraphCaptureContext(torch.cuda.Stream(device=device)) + with get_tp_group().graph_capture(context), get_pp_group().graph_capture( + context): + yield context + + +logger = init_logger(__name__) + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s", world_size, rank, local_rank, + distributed_init_method, backend) + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None and config.parallel_config.data_parallel_size > 1: + parallel_config = config.parallel_config + # adjust to take into account data parallelism + # offset the rank by the data parallel rank + rank = parallel_config.data_parallel_rank * world_size + rank + # adjust the world size to take into account data parallelism + world_size = parallel_config.world_size_across_dp + ip = parallel_config.data_parallel_master_ip + port = parallel_config.get_next_dp_init_port() + distributed_init_method = f"tcp://{ip}:{port}" # noqa + logger.info( + "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", + world_size, rank, distributed_init_method) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + 
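+                # Launchers such as torchrun export LOCAL_RANK, so prefer it
+                # when the env:// rendezvous is in use.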
local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size") + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + + data_parallel_size = 1 + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + data_parallel_size = config.parallel_config.data_parallel_size + + # the layout order is: ExternalDP x DP x PP x TP + # ExternalDP is the data parallel group that is not part of the model, + # every dp rank can generate independently (in verl integration). + # DP is the data parallel group that is part of the model, + # all the ranks in the same DP group should generate simultaneously, + # i.e. the `generate` call in the same DP group should be called together, + # otherwise it will cause deadlock. + # to get group_ranks for each dimension, transpose that dimension to the + # last dimension, then reshape to 2D, then unbind the last dimension + all_ranks = torch.arange(world_size).reshape( + -1, data_parallel_size, pipeline_model_parallel_size, + tensor_model_parallel_size) # noqa + + # Build the tensor model-parallel groups. + global _TP + assert _TP is None, ("tensor model parallel group is already initialized") + group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp") + + # Build the pipeline model-parallel groups. 
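+    # Ranks that share the same DP and TP indices but sit in different
+    # pipeline stages form one PP group: transposing the PP and TP axes of
+    # `all_ranks` lets each row of the reshaped view enumerate one pipeline.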
+ global _PP + assert _PP is None, ( + "pipeline model parallel group is already initialized") + group_ranks = all_ranks.transpose(2, 3).reshape( + -1, pipeline_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _PP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="pp") + + global _DP + assert _DP is None, ("data parallel group is already initialized") + group_ranks = all_ranks.transpose(1, + 3).reshape(-1, + data_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="dp") + + logger.info( + "rank %s in world size %s is assigned as " + "DP rank %s, PP rank %s, TP rank %s", rank, world_size, + _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + + +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize KV cache transfer parallel group. + """ + + global _KV_TRANSFER + + if vllm_config.kv_transfer_config is None: + return + + if all([ + vllm_config.kv_transfer_config.is_kv_transfer_instance, + _KV_TRANSFER is None + ]): + _KV_TRANSFER = kv_transfer.KVTransferAgent( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, + pipeline_model_parallel_size, backend) + return + + assert ( + get_tensor_model_parallel_world_size() == tensor_model_parallel_size + ), ("tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert (pp_world_size == pipeline_model_parallel_size), ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return (_TP is not None and _PP is not None) + + +_TP_STATE_PATCHED = False +_PP_STATE_PATCHED = False + +@contextmanager +def patch_model_parallel_group(tp_group: GroupCoordinator, + pp_group: GroupCoordinator): + """Patch the tp and pp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. 
+ + Args: + tp_group (GroupCoordinator): the tp group coordinator + pp_group (GroupCoordinator): the pp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + global _PP_STATE_PATCHED + assert not _PP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + + _PP_STATE_PATCHED = True + old_pp_group = get_pp_group() + global _PP + _PP = pp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + _PP_STATE_PATCHED = False + _PP = old_pp_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TP + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + global _DP + if _DP: + _DP.destroy() + _DP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + if shutdown_ray: + import ray # Lazy import Ray + ray.shutdown() + gc.collect() + from vllm.platforms import current_platform + if not current_platform.is_cpu(): + torch.cuda.empty_cache() + try: + torch._C._host_emptyCache() + except AttributeError: + logger.warning( + "torch._C._host_emptyCache() only available in Pytorch >=2.5") + + +def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], + source_rank: int = 0) -> List[bool]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). 
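+    The source rank creates a named shared-memory segment and broadcasts its
+    name; a rank is considered to be on the same node if it can open that
+    segment and read back the magic payload.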
+ """ + if isinstance(pg, ProcessGroup): + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group.") + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + else: + rank = pg.rank + world_size = pg.world_size + ranks = list(range(world_size)) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[:len(magic_message)] = magic_message + if isinstance(pg, ProcessGroup): + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg) + else: + pg.broadcast_obj(shm.name, src=source_rank) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + if isinstance(pg, ProcessGroup): + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg) + name = recv[0] + else: + name = pg.broadcast_obj(None, src=source_rank) + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[:len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + if isinstance(pg, ProcessGroup): + torch.distributed.barrier(group=pg) + else: + pg.barrier() + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + + if isinstance(pg, ProcessGroup): + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + aggregated_data = is_in_the_same_node + else: + aggregated_data = torch.zeros_like(is_in_the_same_node) + for i in range(world_size): + rank_data = pg.broadcast_obj(is_in_the_same_node, src=i) + aggregated_data += rank_data + + return [x == 1 for x in aggregated_data.tolist()] \ No newline at end of file diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py new file mode 100644 index 0000000..32c8467 --- /dev/null +++ b/vllm/distributed/utils.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+import dataclasses +import datetime +import pickle +import time +from collections import deque +from typing import Any, Deque, Dict, Optional, Sequence, Tuple + +import torch +from torch.distributed import ProcessGroup, TCPStore +from torch.distributed.distributed_c10d import (Backend, PrefixStore, + _get_default_timeout, + _unregister_process_group, + is_nccl_available) +from torch.distributed.rendezvous import rendezvous + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def get_pp_indices(num_hidden_layers: int, pp_rank: int, + pp_size: int) -> Tuple[int, int]: + """Try to evenly distribute layers across partitions. + + If the number of layers is not divisible by the number of partitions, + the remaining layers are evenly distributed across all but the last + partition. The last partition is excluded because it often contains an + additional norm layer and we are attempting to balance compute. + + If `pp_size > 2` and the number of remaining layers is + `0 < x <= pp_size - 2` then the remaining layers are evenly distributed + across the middle partitions. The first and last partitions are excluded + because they contain the input and output embeddings respectively and we + are attempting to reduce maximum memory consumption across partitions. 
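+
+    For example, with 29 hidden layers and pp_size == 4 the default split is
+    [7, 7, 8, 7], so pp_rank 2 receives (start_layer, end_layer) == (14, 22).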
+ """ + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_hidden_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_hidden_layers=}.") + else: + layers_per_partition = num_hidden_layers // pp_size + partitions = [layers_per_partition for _ in range(pp_size)] + + if remaining_layers := num_hidden_layers % pp_size: + for i in range(2, remaining_layers + 2): + partitions[-i] += 1 + logger.info("Hidden layers were unevenly partitioned: %s", + ",".join(str(p) for p in partitions)) + logger.info("This can be manually overridden using the " + "VLLM_PP_LAYER_PARTITION environment variable") + + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] + + return (start_layer, end_layer) + + +@dataclasses.dataclass +class StatelessProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + rank: int + world_size: int + store: torch._C._distributed_c10d.Store + data_expiration_seconds: int = 3600 # 1 hour + + # dst rank -> counter + send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + # src rank -> counter + recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + broadcast_send_counter: int = 0 + broadcast_recv_src_counter: Dict[int, int] = dataclasses.field( + default_factory=dict) + + # A deque to store the data entries, with key and timestamp. + entries: Deque[Tuple[str, + float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + assert self.rank < self.world_size + self.send_dst_counter = {i: 0 for i in range(self.world_size)} + self.recv_src_counter = {i: 0 for i in range(self.world_size)} + self.broadcast_recv_src_counter = { + i: 0 + for i in range(self.world_size) + } + + def send_obj(self, obj: Any, dst: int): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self, src: int) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get( + f"send_to/{self.rank}/{self.recv_src_counter[src]}")) + self.recv_src_counter[src] += 1 + return obj + + def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: + """Broadcast an object from a source rank to all other ranks. + It does not clean up after all ranks have received the object. + Use it for limited times, e.g., for initialization. 
+ """ + if self.rank == src: + self.expire_data() + key = (f"broadcast_from/{src}/" + f"{self.broadcast_send_counter}") + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return obj + else: + key = (f"broadcast_from/{src}/" + f"{self.broadcast_recv_src_counter[src]}") + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj + + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs + + def barrier(self): + """A barrier to synchronize all ranks.""" + for i in range(self.world_size): + self.broadcast_obj(None, src=i) + + @staticmethod + def create( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + store_timeout: int = 300, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. + """ # noqa + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=(rank == 0), + timeout=datetime.timedelta(seconds=store_timeout), + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + data_expiration_seconds=data_expiration_seconds) + + +def stateless_init_torch_distributed_process_group( + host: str, port: int, rank: int, world_size: int, + backend: str) -> ProcessGroup: + """ + A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. The created ProcessGroup object can be used for + some operations such as `allreduce`, because it does not depend on the + global rank. However, some operations such as `broadcast` cannot be used + because it depends on the global rank. + + # TODO: ask for help from PyTorch team if we need the `broadcast` operation. + + This function is useful when we are not sure about the total number of + processes in the process group. For example, we may have process + 1, 2, ..., 8 who want to communicate, and process 9 might be the same + process as process 1, or it might be a different process; process 10 + might be the same process as process 5, or it might be a different process. + In this case, how can we reliably form a communication channel within + process 9 and 10, without affecting the communication channel within + process 1, 2, ..., 8? 
+ + One possible solution is to figure out if process 9 and 10 are the same + as process 1 and 5 beforehand, and then form a communication channel + based on the information, adjusting the ranks and world_size etc. However, + figuring out the information is not always easy, and it will interfere + with the main communication channel. + + Our solution is to always form a communication channel with process 1, 2, + ..., 8, and then use this function to form another communication channel + with process 9 and 10. This way, regardless of whether process 9 and 10 + are the same as process 1 and 5, the main communication channel is + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. + """ + init_method = f"tcp://{host}:{port}" + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) + + group_rank = rank + group_size = world_size + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + prefix_store = PrefixStore(init_method, store) + + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ProcessGroup.Options(backend=str("nccl")), + ) + + if backend == "gloo": + from torch.distributed.distributed_c10d import ProcessGroupGloo + backend_class = ProcessGroupGloo(prefix_store, + group_rank, + group_size, + timeout=timeout) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + elif backend == "nccl": + assert is_nccl_available() + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, + backend_options) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + else: + raise RuntimeError(f"Unsupported torch distributed backend: {backend}") + + # unsupported attribute + # pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + + return pg + + +def stateless_destroy_torch_distributed_process_group( + pg: ProcessGroup) -> None: + """ + Destroy ProcessGroup returned by + stateless_init_torch_distributed_process_group(). + """ + # Lazy import for non-CUDA backends. 
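+    # NOTE: `_shutdown_backend` and `_unregister_process_group` are private
+    # torch.distributed internals and may change between PyTorch releases.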
+ from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) + _unregister_process_group(pg.group_name) diff --git a/vllm/engine/__init__.py b/vllm/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py new file mode 100644 index 0000000..c2ffa90 --- /dev/null +++ b/vllm/engine/arg_utils.py @@ -0,0 +1,1734 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import dataclasses +import json +import threading +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, + Tuple, Type, Union, cast, get_args) + +import torch + +import vllm.envs as envs +from vllm import version +from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, + DecodingConfig, DeviceConfig, HfOverrides, + KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ModelImpl, ObservabilityConfig, + ParallelConfig, PoolerConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, TaskOption, + TokenizerPoolConfig, VllmConfig) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.plugins import load_general_plugins +from vllm.reasoning import ReasoningParserManager +from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 +from vllm.transformers_utils.utils import check_gguf_file +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + +logger = init_logger(__name__) + +ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] + +DEVICE_OPTIONS = [ + "auto", + "cuda", + "neuron", + "cpu", + "tpu", + "xpu", + "hpu", +] + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. 
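+
+    For example, "image=16,video=2" is parsed into {"image": 16, "video": 2}.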
+ """ + if len(val) == 0: + return None + + out_dict: Dict[str, int] = {} + for item in val.split(","): + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts + + try: + parsed_value = int(value) + except ValueError as exc: + msg = f"Failed to parse value of item {key}={value}" + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value + + return out_dict + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str = 'facebook/opt-125m' + served_model_name: Optional[Union[str, List[str]]] = None + tokenizer: Optional[str] = None + hf_config_path: Optional[str] = None + task: TaskOption = "auto" + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + allowed_local_media_path: str = "" + download_dir: Optional[str] = None + load_format: str = 'auto' + config_format: ConfigFormat = ConfigFormat.AUTO + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + seed: Optional[int] = None + max_model_len: Optional[int] = None + # Note: Specifying a custom executor backend by passing a class + # is intended for expert use only. The API may change without + # notice. + distributed_executor_backend: Optional[Union[str, + Type[ExecutorBase]]] = None + # number of P/D disaggregation (or other disaggregation) workers + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + data_parallel_size: int = 1 + enable_expert_parallel: bool = False + max_parallel_loading_workers: Optional[int] = None + block_size: Optional[int] = None + enable_prefix_caching: Optional[bool] = None + prefix_caching_hash_algo: str = "builtin" + disable_sliding_window: bool = False + disable_cascade_attn: bool = False + use_v2_block_manager: bool = True + swap_space: float = 4 # GiB + cpu_offload_gb: float = 0 # GiB + gpu_memory_utilization: float = 0.90 + max_num_batched_tokens: Optional[int] = None + max_num_partial_prefills: Optional[int] = 1 + max_long_partial_prefills: Optional[int] = 1 + long_prefill_token_threshold: Optional[int] = 0 + max_num_seqs: Optional[int] = None + max_logprobs: int = 20 # Default value for OpenAI Chat Completions API + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[Dict[str, Any]] = None + rope_theta: Optional[float] = None + hf_overrides: Optional[HfOverrides] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + enforce_eager: Optional[bool] = None + max_seq_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False + tokenizer_pool_size: int = 0 + # Note: Specifying a tokenizer pool by passing a class + # is intended for expert use only. The API may change without + # notice. 
+ tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" + tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None + limit_mm_per_prompt: Optional[Mapping[str, int]] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None + disable_mm_preprocessor_cache: bool = False + enable_lora: bool = False + enable_lora_bias: bool = False + max_loras: int = 1 + max_lora_rank: int = 16 + enable_prompt_adapter: bool = False + max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 0 + fully_sharded_loras: bool = False + lora_extra_vocab_size: int = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None + lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' + max_cpu_loras: Optional[int] = None + device: str = 'auto' + num_scheduler_steps: int = 1 + multi_step_stream_outputs: bool = True + ray_workers_use_nsight: bool = False + num_gpu_blocks_override: Optional[int] = None + num_lookahead_slots: int = 0 + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + preemption_mode: Optional[str] = None + + scheduler_delay_factor: float = 0.0 + enable_chunked_prefill: Optional[bool] = None + + guided_decoding_backend: str = 'xgrammar' + logits_processor_pattern: Optional[str] = None + + speculative_config: Optional[Dict[str, Any]] = None + + qlora_adapter_name_or_path: Optional[str] = None + show_hidden_metrics_for_version: Optional[str] = None + otlp_traces_endpoint: Optional[str] = None + collect_detailed_traces: Optional[str] = None + disable_async_output_proc: bool = False + scheduling_policy: Literal["fcfs", "priority"] = "fcfs" + scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" + + override_neuron_config: Optional[Dict[str, Any]] = None + override_pooler_config: Optional[PoolerConfig] = None + compilation_config: Optional[CompilationConfig] = None + worker_cls: str = "auto" + worker_extension_cls: str = "" + + kv_transfer_config: Optional[KVTransferConfig] = None + + generation_config: Optional[str] = "auto" + override_generation_config: Optional[Dict[str, Any]] = None + enable_sleep_mode: bool = False + model_impl: str = "auto" + + calculate_kv_scales: Optional[bool] = None + + additional_config: Optional[Dict[str, Any]] = None + enable_reasoning: Optional[bool] = None + reasoning_parser: Optional[str] = None + use_tqdm_on_load: bool = True + + def __post_init__(self): + if not self.tokenizer: + self.tokenizer = self.model + + # support `EngineArgs(compilation_config={...})` + # without having to manually construct a + # CompilationConfig object + if isinstance(self.compilation_config, (int, dict)): + self.compilation_config = CompilationConfig.from_cli( + str(self.compilation_config)) + + # Setup plugins + from vllm.plugins import load_general_plugins + load_general_plugins() + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Shared CLI arguments for vLLM engine.""" + # Model arguments + parser.add_argument( + '--model', + type=str, + default=EngineArgs.model, + help='Name or path of the huggingface model to use.') + parser.add_argument( + '--task', + default=EngineArgs.task, + choices=get_args(TaskOption), + help='The task to use the model for. Each vLLM instance only ' + 'supports one task, even if the same model can be used for ' + 'multiple tasks. 
When the model only supports one task, ``"auto"`` ' + 'can be used to select it; otherwise, you must specify explicitly ' + 'which task to use.') + parser.add_argument( + '--tokenizer', + type=nullable_str, + default=EngineArgs.tokenizer, + help='Name or path of the huggingface tokenizer to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + "--hf-config-path", + type=nullable_str, + default=EngineArgs.hf_config_path, + help='Name or path of the huggingface config to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + '--skip-tokenizer-init', + action='store_true', + help='Skip initialization of tokenizer and detokenizer. ' + 'Expects valid prompt_token_ids and None for prompt from ' + 'the input. The generated output will contain token ids.') + parser.add_argument( + '--revision', + type=nullable_str, + default=None, + help='The specific model version to use. It can be a branch ' + 'name, a tag name, or a commit id. If unspecified, will use ' + 'the default version.') + parser.add_argument( + '--code-revision', + type=nullable_str, + default=None, + help='The specific revision to use for the model code on ' + 'Hugging Face Hub. It can be a branch name, a tag name, or a ' + 'commit id. If unspecified, will use the default version.') + parser.add_argument( + '--tokenizer-revision', + type=nullable_str, + default=None, + help='Revision of the huggingface tokenizer to use. ' + 'It can be a branch name, a tag name, or a commit id. ' + 'If unspecified, will use the default version.') + parser.add_argument( + '--tokenizer-mode', + type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n* ' + '"custom" will use --tokenizer to select the ' + 'preregistered tokenizer.') + parser.add_argument('--trust-remote-code', + action='store_true', + help='Trust remote code from huggingface.') + parser.add_argument( + '--allowed-local-media-path', + type=str, + help="Allowing API requests to read local images or videos " + "from directories specified by the server file system. " + "This is a security risk. " + "Should only be enabled in trusted environments.") + parser.add_argument('--download-dir', + type=nullable_str, + default=EngineArgs.download_dir, + help='Directory to download and load the weights.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[f.value for f in LoadFormat], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples ' + 'section for more information.\n' + '* "runai_streamer" will load the Safetensors weights using Run:ai' + 'Model Streamer.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n' + '* "sharded_state" will load weights from pre-sharded checkpoint ' + 'files, supporting efficient loading of tensor-parallel models\n' + '* "gguf" will load weights from GGUF format files (details ' + 'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n' + '* "mistral" will load weights from consolidated safetensors files ' + 'used by Mistral models.\n') + parser.add_argument( + '--config-format', + default=EngineArgs.config_format, + choices=[f.value for f in ConfigFormat], + help='The format of the model config to load.\n\n' + '* "auto" will try to load the config in hf format ' + 'if available else it will try to load in mistral format ') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='Data type for model weights and activations.\n\n' + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + 'BF16 precision for BF16 models.\n' + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.') + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default=EngineArgs.kv_cache_dtype, + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument('--max-model-len', + type=int, + default=EngineArgs.max_model_len, + help='Model context length. If unspecified, will ' + 'be automatically derived from the model config.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='xgrammar', + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc) by default. Currently support ' + 'https://github.com/mlc-ai/xgrammar and ' + 'https://github.com/guidance-ai/llguidance.' + 'Valid backend values are "xgrammar", "guidance", and "auto". ' + 'With "auto", we will make opinionated choices based on request' + 'contents and what the backend libraries currently support, so ' + 'the behavior is subject to change in each release.') + parser.add_argument( + '--logits-processor-pattern', + type=nullable_str, + default=None, + help='Optional regex pattern specifying valid logits processor ' + 'qualified names that can be passed with the `logits_processors` ' + 'extra completion argument. 
Defaults to None, which allows no ' + 'processors.') + parser.add_argument( + '--model-impl', + type=str, + default=EngineArgs.model_impl, + choices=[f.value for f in ModelImpl], + help='Which implementation of the model to use.\n\n' + '* "auto" will try to use the vLLM implementation if it exists ' + 'and fall back to the Transformers implementation if no vLLM ' + 'implementation is available.\n' + '* "vllm" will use the vLLM model implementation.\n' + '* "transformers" will use the Transformers model ' + 'implementation.\n') + # Parallel arguments + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp', 'uni', 'external_launcher'], + default=EngineArgs.distributed_executor_backend, + help='Backend to use for distributed model ' + 'workers, either "ray" or "mp" (multiprocessing). If the product ' + 'of pipeline_parallel_size and tensor_parallel_size is less than ' + 'or equal to the number of GPUs available, "mp" will be used to ' + 'keep processing on a single host. Otherwise, this will default ' + 'to "ray" if Ray is installed and fail otherwise. Note that tpu ' + 'only supports Ray for distributed inference.') + + parser.add_argument('--pipeline-parallel-size', + '-pp', + type=int, + default=EngineArgs.pipeline_parallel_size, + help='Number of pipeline stages.') + parser.add_argument('--tensor-parallel-size', + '-tp', + type=int, + default=EngineArgs.tensor_parallel_size, + help='Number of tensor parallel replicas.') + parser.add_argument('--data-parallel-size', + '-dp', + type=int, + default=EngineArgs.data_parallel_size, + help='Number of data parallel replicas. ' + 'MoE layers will be sharded according to the ' + 'product of the tensor-parallel-size and ' + 'data-parallel-size.') + parser.add_argument( + '--enable-expert-parallel', + action='store_true', + help='Use expert parallelism instead of tensor parallelism ' + 'for MoE layers.') + parser.add_argument( + '--max-parallel-loading-workers', + type=int, + default=EngineArgs.max_parallel_loading_workers, + help='Load model sequentially in multiple batches, ' + 'to avoid RAM OOM when using tensor ' + 'parallel and large models.') + parser.add_argument( + '--ray-workers-use-nsight', + action='store_true', + help='If specified, use nsight to profile Ray workers.') + # KV cache arguments + parser.add_argument('--block-size', + type=int, + default=EngineArgs.block_size, + choices=[8, 16, 32, 64, 128], + help='Token block size for contiguous chunks of ' + 'tokens. This is ignored on neuron devices and ' + 'set to ``--max-model-len``. On CUDA devices, ' + 'only block sizes up to 32 are supported. ' + 'On HPU devices, block size defaults to 128.') + + parser.add_argument( + "--enable-prefix-caching", + action=argparse.BooleanOptionalAction, + default=EngineArgs.enable_prefix_caching, + help="Enables automatic prefix caching. " + "Use ``--no-enable-prefix-caching`` to disable explicitly.", + ) + parser.add_argument( + "--prefix-caching-hash-algo", + type=str, + choices=["builtin", "sha256"], + default=EngineArgs.prefix_caching_hash_algo, + help="Set the hash algorithm for prefix caching. 
" + "Options are 'builtin' (Python's built-in hash) or 'sha256' " + "(collision resistant but with certain overheads).", + ) + parser.add_argument('--disable-sliding-window', + action='store_true', + help='Disables sliding window, ' + 'capping to sliding window size.') + parser.add_argument('--use-v2-block-manager', + action='store_true', + default=True, + help='[DEPRECATED] block manager v1 has been ' + 'removed and SelfAttnBlockSpaceManager (i.e. ' + 'block manager v2) is now the default. ' + 'Setting this flag to True or False' + ' has no effect on vLLM behavior.') + parser.add_argument( + '--num-lookahead-slots', + type=int, + default=EngineArgs.num_lookahead_slots, + help='Experimental scheduling config necessary for ' + 'speculative decoding. This will be replaced by ' + 'speculative config in the future; it is present ' + 'to enable correctness tests until then.') + + parser.add_argument('--seed', + type=int, + default=EngineArgs.seed, + help='Random seed for operations.') + parser.add_argument('--swap-space', + type=float, + default=EngineArgs.swap_space, + help='CPU swap space size (GiB) per GPU.') + parser.add_argument( + '--cpu-offload-gb', + type=float, + default=0, + help='The space in GiB to offload to CPU, per GPU. ' + 'Default is 0, which means no offloading. Intuitively, ' + 'this argument can be seen as a virtual way to increase ' + 'the GPU memory size. For example, if you have one 24 GB ' + 'GPU and set this to 10, virtually you can think of it as ' + 'a 34 GB GPU. Then you can load a 13B model with BF16 weight, ' + 'which requires at least 26GB GPU memory. Note that this ' + 'requires fast CPU-GPU interconnect, as part of the model is ' + 'loaded from CPU memory to GPU memory on the fly in each ' + 'model forward pass.') + parser.add_argument( + '--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='The fraction of GPU memory to be used for the model ' + 'executor, which can range from 0 to 1. For example, a value of ' + '0.5 would imply 50%% GPU memory utilization. If unspecified, ' + 'will use the default value of 0.9. This is a per-instance ' + 'limit, and only applies to the current vLLM instance.' + 'It does not matter if you have another vLLM instance running ' + 'on the same GPU. For example, if you have two vLLM instances ' + 'running on the same GPU, you can set the GPU memory utilization ' + 'to 0.5 for each instance.') + parser.add_argument( + '--num-gpu-blocks-override', + type=int, + default=None, + help='If specified, ignore GPU profiling result and use this number' + ' of GPU blocks. Used for testing preemption.') + parser.add_argument('--max-num-batched-tokens', + type=int, + default=EngineArgs.max_num_batched_tokens, + help='Maximum number of batched tokens per ' + 'iteration.') + parser.add_argument( + "--max-num-partial-prefills", + type=int, + default=EngineArgs.max_num_partial_prefills, + help="For chunked prefill, the max number of concurrent \ + partial prefills.") + parser.add_argument( + "--max-long-partial-prefills", + type=int, + default=EngineArgs.max_long_partial_prefills, + help="For chunked prefill, the maximum number of prompts longer " + "than --long-prefill-token-threshold that will be prefilled " + "concurrently. 
Setting this less than --max-num-partial-prefills " + "will allow shorter prompts to jump the queue in front of longer " + "prompts in some cases, improving latency.") + parser.add_argument( + "--long-prefill-token-threshold", + type=float, + default=EngineArgs.long_prefill_token_threshold, + help="For chunked prefill, a request is considered long if the " + "prompt is longer than this number of tokens.") + parser.add_argument('--max-num-seqs', + type=int, + default=EngineArgs.max_num_seqs, + help='Maximum number of sequences per iteration.') + parser.add_argument( + '--max-logprobs', + type=int, + default=EngineArgs.max_logprobs, + help=('Max number of log probs to return logprobs is specified in' + ' SamplingParams.')) + parser.add_argument('--disable-log-stats', + action='store_true', + help='Disable logging statistics.') + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=nullable_str, + choices=[*QUANTIZATION_METHODS, None], + default=EngineArgs.quantization, + help='Method used to quantize the weights. If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument( + '--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, ``{"rope_type":"dynamic","factor":2.0}``') + parser.add_argument('--rope-theta', + default=None, + type=float, + help='RoPE theta. Use with `rope_scaling`. In ' + 'some cases, changing the RoPE theta improves the ' + 'performance of the scaled model.') + parser.add_argument('--hf-overrides', + type=json.loads, + default=EngineArgs.hf_overrides, + help='Extra arguments for the HuggingFace config. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + parser.add_argument('--enforce-eager', + action='store_true', + help='Always use eager-mode PyTorch. If False, ' + 'will use eager mode and CUDA graph in hybrid ' + 'for maximal performance and flexibility.') + parser.add_argument('--max-seq-len-to-capture', + type=int, + default=EngineArgs.max_seq_len_to_capture, + help='Maximum sequence length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + 'Additionally for encoder-decoder models, if the ' + 'sequence length of the encoder input is larger ' + 'than this, we fall back to the eager mode.') + parser.add_argument('--disable-custom-all-reduce', + action='store_true', + default=EngineArgs.disable_custom_all_reduce, + help='See ParallelConfig.') + parser.add_argument('--tokenizer-pool-size', + type=int, + default=EngineArgs.tokenizer_pool_size, + help='Size of tokenizer pool to use for ' + 'asynchronous tokenization. If 0, will ' + 'use synchronous tokenization.') + parser.add_argument('--tokenizer-pool-type', + type=str, + default=EngineArgs.tokenizer_pool_type, + help='Type of tokenizer pool to use for ' + 'asynchronous tokenization. Ignored ' + 'if tokenizer_pool_size is 0.') + parser.add_argument('--tokenizer-pool-extra-config', + type=nullable_str, + default=EngineArgs.tokenizer_pool_extra_config, + help='Extra config for tokenizer pool. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary. 
Ignored if ' + 'tokenizer_pool_size is 0.') + + # Multimodal related configs + parser.add_argument( + '--limit-mm-per-prompt', + type=nullable_kvs, + default=EngineArgs.limit_mm_per_prompt, + # The default value is given in + # MultiModalConfig.get_limit_per_prompt + help=('For each multimodal plugin, limit how many ' + 'input instances to allow for each prompt. ' + 'Expects a comma-separated list of items, ' + 'e.g.: `image=16,video=2` allows a maximum of 16 ' + 'images and 2 videos per prompt. Defaults to 1 for ' + 'each modality.')) + parser.add_argument( + '--mm-processor-kwargs', + default=None, + type=json.loads, + help=('Overrides for the multimodal input mapping/processing, ' + 'e.g., image processor. For example: ``{"num_crops": 4}``.')) + parser.add_argument( + '--disable-mm-preprocessor-cache', + action='store_true', + help='If true, then disables caching of the multi-modal ' + 'preprocessor/mapper. (not recommended)') + + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='If True, enable handling of LoRA adapters.') + parser.add_argument('--enable-lora-bias', + action='store_true', + help='If True, enable bias for LoRA adapters.') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='Max number of LoRAs in a single batch.') + parser.add_argument('--max-lora-rank', + type=int, + default=EngineArgs.max_lora_rank, + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) + parser.add_argument( + '--long-lora-scaling-factors', + type=nullable_str, + default=EngineArgs.long_lora_scaling_factors, + help=('Specify multiple scaling factors (which can ' + 'be different from base model scaling factor ' + '- see eg. Long LoRA) to allow for multiple ' + 'LoRA adapters trained with those scaling ' + 'factors to be used at the same time. If not ' + 'specified, only adapters trained with the ' + 'base model scaling factor are allowed.')) + parser.add_argument( + '--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' + 'Must be >= than max_loras.')) + parser.add_argument( + '--fully-sharded-loras', + action='store_true', + help=('By default, only half of the LoRA computation is ' + 'sharded with tensor parallelism. ' + 'Enabling this will use the fully sharded layers. 
' + 'At high sequence length, max rank or ' + 'tensor parallel size, this is likely faster.')) + parser.add_argument('--enable-prompt-adapter', + action='store_true', + help='If True, enable handling of PromptAdapters.') + parser.add_argument('--max-prompt-adapters', + type=int, + default=EngineArgs.max_prompt_adapters, + help='Max number of PromptAdapters in a batch.') + parser.add_argument('--max-prompt-adapter-token', + type=int, + default=EngineArgs.max_prompt_adapter_token, + help='Max number of PromptAdapters tokens') + parser.add_argument("--device", + type=str, + default=EngineArgs.device, + choices=DEVICE_OPTIONS, + help='Device type for vLLM execution.') + parser.add_argument('--num-scheduler-steps', + type=int, + default=1, + help=('Maximum number of forward steps per ' + 'scheduler call.')) + parser.add_argument( + '--use-tqdm-on-load', + dest='use_tqdm_on_load', + action=argparse.BooleanOptionalAction, + default=EngineArgs.use_tqdm_on_load, + help='Whether to enable/disable progress bar ' + 'when loading model weights.', + ) + + parser.add_argument( + '--multi-step-stream-outputs', + action=StoreBoolean, + default=EngineArgs.multi_step_stream_outputs, + nargs="?", + const="True", + help='If False, then multi-step will stream outputs at the end ' + 'of all steps') + parser.add_argument( + '--scheduler-delay-factor', + type=float, + default=EngineArgs.scheduler_delay_factor, + help='Apply a delay (of delay factor multiplied by previous ' + 'prompt latency) before scheduling next prompt.') + parser.add_argument( + '--enable-chunked-prefill', + action=StoreBoolean, + default=EngineArgs.enable_chunked_prefill, + nargs="?", + const="True", + help='If set, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens.') + parser.add_argument('--speculative-config', + type=json.loads, + default=None, + help='The configurations for speculative decoding.' + ' Should be a JSON string.') + + parser.add_argument('--model-loader-extra-config', + type=nullable_str, + default=EngineArgs.model_loader_extra_config, + help='Extra config for model loader. ' + 'This will be passed to the model loader ' + 'corresponding to the chosen load_format. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + parser.add_argument( + '--ignore-patterns', + action="append", + type=str, + default=[], + help="The pattern(s) to ignore when loading the model." + "Default to `original/**/*` to avoid repeated loading of llama's " + "checkpoints.") + parser.add_argument( + '--preemption-mode', + type=str, + default=None, + help='If \'recompute\', the engine performs preemption by ' + 'recomputing; If \'swap\', the engine performs preemption by ' + 'block swapping.') + + parser.add_argument( + "--served-model-name", + nargs="+", + type=str, + default=None, + help="The model name(s) used in the API. If multiple " + "names are provided, the server will respond to any " + "of the provided names. The model name in the model " + "field of a response will be the first name in this " + "list. If not specified, the model name will be the " + "same as the ``--model`` argument. 
Noted that this name(s) " + "will also be used in `model_name` tag content of " + "prometheus metrics, if multiple names provided, metrics " + "tag will take the first one.") + parser.add_argument('--qlora-adapter-name-or-path', + type=str, + default=None, + help='Name or path of the QLoRA adapter.') + + parser.add_argument('--show-hidden-metrics-for-version', + type=str, + default=None, + help='Enable deprecated Prometheus metrics that ' + 'have been hidden since the specified version. ' + 'For example, if a previously deprecated metric ' + 'has been hidden since the v0.7.0 release, you ' + 'use --show-hidden-metrics-for-version=0.7 as a ' + 'temporary escape hatch while you migrate to new ' + 'metrics. The metric is likely to be removed ' + 'completely in an upcoming release.') + + parser.add_argument( + '--otlp-traces-endpoint', + type=str, + default=None, + help='Target URL to which OpenTelemetry traces will be sent.') + parser.add_argument( + '--collect-detailed-traces', + type=str, + default=None, + help="Valid choices are " + + ",".join(ALLOWED_DETAILED_TRACE_MODULES) + + ". It makes sense to set this only if ``--otlp-traces-endpoint`` is" + " set. If set, it will collect detailed traces for the specified " + "modules. This involves use of possibly costly and or blocking " + "operations and hence might have a performance impact.") + + parser.add_argument( + '--disable-async-output-proc', + action='store_true', + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") + + parser.add_argument( + '--scheduling-policy', + choices=['fcfs', 'priority'], + default="fcfs", + help='The scheduling policy to use. "fcfs" (first come first served' + ', i.e. requests are handled in order of arrival; default) ' + 'or "priority" (requests are handled based on given ' + 'priority (lower value means earlier handling) and time of ' + 'arrival deciding any ties).') + + parser.add_argument( + '--scheduler-cls', + default=EngineArgs.scheduler_cls, + help='The scheduler class to use. "vllm.core.scheduler.Scheduler" ' + 'is the default scheduler. Can be a class directly or the path to ' + 'a class of form "mod.custom_class".') + + parser.add_argument( + '--override-neuron-config', + type=json.loads, + default=None, + help="Override or set neuron device configuration. " + "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.") + parser.add_argument( + '--override-pooler-config', + type=PoolerConfig.from_json, + default=None, + help="Override or set the pooling method for pooling models. " + "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.") + + parser.add_argument('--compilation-config', + '-O', + type=CompilationConfig.from_cli, + default=None, + help='torch.compile configuration for the model.' + 'When it is a number (0, 1, 2, 3), it will be ' + 'interpreted as the optimization level.\n' + 'NOTE: level 0 is the default level without ' + 'any optimization. level 1 and 2 are for internal ' + 'testing only. level 3 is the recommended level ' + 'for production.\n' + 'To specify the full compilation config, ' + 'use a JSON string.\n' + 'Following the convention of traditional ' + 'compilers, using -O without space is also ' + 'supported. -O3 is equivalent to -O 3.') + + parser.add_argument('--kv-transfer-config', + type=KVTransferConfig.from_cli, + default=None, + help='The configurations for distributed KV cache ' + 'transfer. 
Should be a JSON string.') + + parser.add_argument( + '--worker-cls', + type=str, + default="auto", + help='The worker class to use for distributed execution.') + parser.add_argument( + '--worker-extension-cls', + type=str, + default="", + help='The worker extension class on top of the worker cls, ' + 'it is useful if you just want to add new functions to the worker ' + 'class without changing the existing functions.') + parser.add_argument( + "--generation-config", + type=nullable_str, + default="auto", + help="The folder path to the generation config. " + "Defaults to 'auto', the generation config will be loaded from " + "model path. If set to 'vllm', no generation config is loaded, " + "vLLM defaults will be used. If set to a folder path, the " + "generation config will be loaded from the specified folder path. " + "If `max_new_tokens` is specified in generation config, then " + "it sets a server-wide limit on the number of output tokens " + "for all requests.") + + parser.add_argument( + "--override-generation-config", + type=json.loads, + default=None, + help="Overrides or sets generation config in JSON format. " + "e.g. ``{\"temperature\": 0.5}``. If used with " + "--generation-config=auto, the override parameters will be merged " + "with the default config from the model. If generation-config is " + "None, only the override parameters are used.") + + parser.add_argument("--enable-sleep-mode", + action="store_true", + default=False, + help="Enable sleep mode for the engine. " + "(only cuda platform is supported)") + + parser.add_argument( + '--calculate-kv-scales', + action='store_true', + help='This enables dynamic calculation of ' + 'k_scale and v_scale when kv-cache-dtype is fp8. ' + 'If calculate-kv-scales is false, the scales will ' + 'be loaded from the model checkpoint if available. ' + 'Otherwise, the scales will default to 1.0.') + + parser.add_argument( + "--additional-config", + type=json.loads, + default=None, + help="Additional config for specified platform in JSON format. " + "Different platforms may support different configs. Make sure the " + "configs are valid for the platform you are using. The input format" + " is like '{\"config_key\":\"config_value\"}'") + + parser.add_argument( + "--enable-reasoning", + action="store_true", + default=False, + help="Whether to enable reasoning_content for the model. " + "If enabled, the model will be able to generate reasoning content." + ) + + parser.add_argument( + "--reasoning-parser", + type=str, + choices=list(ReasoningParserManager.reasoning_parsers), + default=None, + help= + "Select the reasoning parser depending on the model that you're " + "using. This is used to parse the reasoning content into OpenAI " + "API format. Required for ``--enable-reasoning``.") + + parser.add_argument( + "--disable-cascade-attn", + action="store_true", + default=False, + help="Disable cascade attention for V1. While cascade attention " + "does not change the mathematical correctness, disabling it " + "could be useful for preventing potential numerical issues. " + "Note that even if this is set to False, cascade attention will be " + "only used when the heuristic tells that it's beneficial.") + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. 
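+ # Every EngineArgs dataclass field has an argparse destination of the
+ # same name registered in add_cli_args(), so the parsed Namespace can
+ # be unpacked straight back into the dataclass below.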
+ engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_model_config(self) -> ModelConfig: + # gguf file needs a specific model loader and doesn't use hf_repo + if check_gguf_file(self.model): + self.quantization = self.load_format = "gguf" + + # NOTE: This is to allow model loading from S3 in CI + if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == LoadFormat.AUTO): # noqa: E501 + self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" + self.load_format = LoadFormat.RUNAI_STREAMER + + return ModelConfig( + model=self.model, + hf_config_path=self.hf_config_path, + task=self.task, + # We know this is not None because we set it in __post_init__ + tokenizer=cast(str, self.tokenizer), + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + allowed_local_media_path=self.allowed_local_media_path, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + hf_overrides=self.hf_overrides, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + enforce_eager=self.enforce_eager, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + disable_cascade_attn=self.disable_cascade_attn, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, + override_neuron_config=self.override_neuron_config, + override_pooler_config=self.override_pooler_config, + logits_processor_pattern=self.logits_processor_pattern, + generation_config=self.generation_config, + override_generation_config=self.override_generation_config, + enable_sleep_mode=self.enable_sleep_mode, + model_impl=self.model_impl, + ) + + def create_load_config(self) -> LoadConfig: + + if(self.qlora_adapter_name_or_path is not None) and \ + self.quantization != "bitsandbytes": + raise ValueError( + "QLoRA adapter only support " + f"'bitsandbytes' quantization, but got {self.quantization}") + + if self.quantization == "bitsandbytes": + self.load_format = "bitsandbytes" + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + use_tqdm_on_load=self.use_tqdm_on_load, + ) + + def create_speculative_config( + self, + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + enable_chunked_prefill: bool, + disable_log_stats: bool, + ) -> Optional["SpeculativeConfig"]: + """Initializes and returns a SpeculativeConfig object based on + `speculative_config`. + + This function utilizes `speculative_config` to create a + SpeculativeConfig object. The `speculative_config` can either be + provided as a JSON string input via CLI arguments or directly as a + dictionary from the engine. + """ + if self.speculative_config is None: + return None + + # Note(Shangming): These parameters are not obtained from the cli arg + # '--speculative-config' and must be passed in when creating the engine + # config. 
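+ # At this point self.speculative_config is a plain dict: either parsed
+ # from the --speculative-config JSON string (type=json.loads) or passed
+ # in directly when the engine is constructed programmatically.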
+ self.speculative_config.update({ + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + }) + speculative_config = SpeculativeConfig.from_dict( + self.speculative_config) + + return speculative_config + + def create_engine_config( + self, + usage_context: Optional[UsageContext] = None, + ) -> VllmConfig: + """ + Create the VllmConfig. + + NOTE: for autoselection of V0 vs V1 engine, we need to + create the ModelConfig first, since ModelConfig's attrs + (e.g. the model arch) are needed to make the decision. + + This function set VLLM_USE_V1=X if VLLM_USE_V1 is + unspecified by the user. + + If VLLM_USE_V1 is specified by the user but the VllmConfig + is incompatible, we raise an error. + """ + from vllm.platforms import current_platform + current_platform.pre_register_and_update() + + device_config = DeviceConfig(device=self.device) + model_config = self.create_model_config() + + # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" + # and fall back to V0 for experimental or unsupported features. + # * If VLLM_USE_V1=1, we enable V1 for supported + experimental + # features and raise error for unsupported features. + # * If VLLM_USE_V1=0, we disable V1. + use_v1 = False + try_v1 = envs.VLLM_USE_V1 + if try_v1 and self._is_v1_supported_oracle(model_config): + use_v1 = True + + # If user explicitly set VLLM_USE_V1, sanity check we respect it. + if envs.is_set("VLLM_USE_V1"): + assert use_v1 == envs.VLLM_USE_V1 + # Otherwise, set the VLLM_USE_V1 variable globally. + else: + envs.set_vllm_use_v1(use_v1) + + # Set default arguments for V0 or V1 Engine. + if use_v1: + self._set_default_args_v1(usage_context) + else: + self._set_default_args_v0(model_config) + + assert self.enable_chunked_prefill is not None + + cache_config = CacheConfig( + block_size=self.block_size, + gpu_memory_utilization=self.gpu_memory_utilization, + swap_space=self.swap_space, + cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + prefix_caching_hash_algo=self.prefix_caching_hash_algo, + cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, + ) + + # Get the current placement group if Ray is initialized and + # we are in a Ray actor. If so, then the placement group will be + # passed to spawned processes. + placement_group = None + if is_in_ray_actor(): + import ray + + # This call initializes Ray automatically if it is not initialized, + # but we should not do this here. 
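+ # ray is imported lazily here, so the dependency is only required when
+ # we are actually running inside a Ray actor.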
+ placement_group = ray.util.get_current_placement_group() + + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + num_virtual_engine=self.pipeline_parallel_size, + data_parallel_size=self.data_parallel_size, + enable_expert_parallel=self.enable_expert_parallel, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + tokenizer_pool_config=TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), + ray_workers_use_nsight=self.ray_workers_use_nsight, + placement_group=placement_group, + distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + worker_extension_cls=self.worker_extension_cls, + ) + + speculative_config = self.create_speculative_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + enable_chunked_prefill=self.enable_chunked_prefill, + disable_log_stats=self.disable_log_stats, + ) + + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + if self.num_scheduler_steps > 1: + if speculative_config is not None: + raise ValueError("Speculative decoding is not supported with " + "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: + raise ValueError("Multi-Step Chunked-Prefill is not supported " + "for pipeline-parallel-size > 1") + from vllm.platforms import current_platform + if current_platform.is_cpu(): + logger.warning("Multi-Step (--num-scheduler-steps > 1) is " + "currently not supported for CPUs and has been " + "disabled.") + self.num_scheduler_steps = 1 + + # make sure num_lookahead_slots is set the higher value depending on + # if we are using speculative decoding or multi-step + num_lookahead_slots = max(self.num_lookahead_slots, + self.num_scheduler_steps - 1) + num_lookahead_slots = num_lookahead_slots \ + if speculative_config is None \ + else speculative_config.num_lookahead_slots + + scheduler_config = SchedulerConfig( + runner_type=model_config.runner_type, + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + num_lookahead_slots=num_lookahead_slots, + delay_factor=self.scheduler_delay_factor, + enable_chunked_prefill=self.enable_chunked_prefill, + is_multimodal_model=model_config.is_multimodal_model, + preemption_mode=self.preemption_mode, + num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER + and parallel_config.use_ray), + policy=self.scheduling_policy, + scheduler_cls=self.scheduler_cls, + max_num_partial_prefills=self.max_num_partial_prefills, + max_long_partial_prefills=self.max_long_partial_prefills, + long_prefill_token_threshold=self.long_prefill_token_threshold, + ) + + lora_config = LoRAConfig( + bias_enabled=self.enable_lora_bias, + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None + + if self.qlora_adapter_name_or_path is not None and \ + 
self.qlora_adapter_name_or_path != "": + if self.model_loader_extra_config is None: + self.model_loader_extra_config = {} + self.model_loader_extra_config[ + "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + + load_config = self.create_load_config() + + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend, + reasoning_backend=self.reasoning_parser + if self.enable_reasoning else None, + ) + + show_hidden_metrics = False + if self.show_hidden_metrics_for_version is not None: + show_hidden_metrics = version._prev_minor_version_was( + self.show_hidden_metrics_for_version) + + detailed_trace_modules = [] + if self.collect_detailed_traces is not None: + detailed_trace_modules = self.collect_detailed_traces.split(",") + for m in detailed_trace_modules: + if m not in ALLOWED_DETAILED_TRACE_MODULES: + raise ValueError( + f"Invalid module {m} in collect_detailed_traces. " + f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}") + observability_config = ObservabilityConfig( + show_hidden_metrics=show_hidden_metrics, + otlp_traces_endpoint=self.otlp_traces_endpoint, + collect_model_forward_time="model" in detailed_trace_modules + or "all" in detailed_trace_modules, + collect_model_execute_time="worker" in detailed_trace_modules + or "all" in detailed_trace_modules, + ) + + config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, + compilation_config=self.compilation_config, + kv_transfer_config=self.kv_transfer_config, + additional_config=self.additional_config, + ) + + return config + + def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: + """Oracle for whether to use V0 or V1 Engine by default.""" + + ############################################################# + # Unsupported Feature Flags on V1. 
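+ # Every check below funnels through _raise_or_fallback(): if the user
+ # explicitly set VLLM_USE_V1=1 it raises NotImplementedError, otherwise
+ # it logs a warning and this oracle returns False so the engine falls
+ # back to V0.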
+ + if (self.load_format == LoadFormat.TENSORIZER.value + or self.load_format == LoadFormat.SHARDED_STATE.value): + _raise_or_fallback( + feature_name=f"--load_format {self.load_format}", + recommend_to_remove=False) + return False + + if (self.logits_processor_pattern + != EngineArgs.logits_processor_pattern): + _raise_or_fallback(feature_name="--logits-processor-pattern", + recommend_to_remove=False) + return False + + if self.preemption_mode != EngineArgs.preemption_mode: + _raise_or_fallback(feature_name="--preemption-mode", + recommend_to_remove=True) + return False + + if (self.disable_async_output_proc + != EngineArgs.disable_async_output_proc): + _raise_or_fallback(feature_name="--disable-async-output-proc", + recommend_to_remove=True) + return False + + if self.scheduling_policy != EngineArgs.scheduling_policy: + _raise_or_fallback(feature_name="--scheduling-policy", + recommend_to_remove=False) + return False + + if self.num_scheduler_steps != EngineArgs.num_scheduler_steps: + _raise_or_fallback(feature_name="--num-scheduler-steps", + recommend_to_remove=True) + return False + + if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor: + _raise_or_fallback(feature_name="--scheduler-delay-factor", + recommend_to_remove=True) + return False + + if self.additional_config != EngineArgs.additional_config: + _raise_or_fallback(feature_name="--additional-config", + recommend_to_remove=False) + return False + + # Xgrammar and Guidance are supported. + SUPPORTED_GUIDED_DECODING = [ + "xgrammar", "xgrammar:disable-any-whitespace", "guidance", + "guidance:disable-any-whitespace", "auto" + ] + if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING: + _raise_or_fallback(feature_name="--guided-decoding-backend", + recommend_to_remove=False) + return False + + # Need at least Ampere for now (FA support required). + # Skip this check if we are running on a non-GPU platform, + # or if the device capability is not available + # (e.g. in a Ray actor without GPUs). + from vllm.platforms import current_platform + if (current_platform.is_cuda() + and current_platform.get_device_capability() + and current_platform.get_device_capability().major < 8): + _raise_or_fallback(feature_name="Compute Capability < 8.0", + recommend_to_remove=False) + return False + + # No Fp8 KV cache so far. + if self.kv_cache_dtype != "auto": + fp8_attention = self.kv_cache_dtype.startswith("fp8") + will_use_fa = ( + current_platform.is_cuda() + and not envs.is_set("VLLM_ATTENTION_BACKEND") + ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + supported = False + if fp8_attention and will_use_fa: + from vllm.vllm_flash_attn.fa_utils import ( + flash_attn_supports_fp8) + supported = flash_attn_supports_fp8() + if not supported: + _raise_or_fallback(feature_name="--kv-cache-dtype", + recommend_to_remove=False) + return False + + # No Prompt Adapter so far. + if self.enable_prompt_adapter: + _raise_or_fallback(feature_name="--enable-prompt-adapter", + recommend_to_remove=False) + return False + + # Only Fp16 and Bf16 dtypes since we only support FA. + V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16] + if model_config.dtype not in V1_SUPPORTED_DTYPES: + _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}", + recommend_to_remove=False) + return False + + # Some quantization is not compatible with torch.compile. 
+ V1_UNSUPPORTED_QUANT = ["gguf"] + if model_config.quantization in V1_UNSUPPORTED_QUANT: + _raise_or_fallback( + feature_name=f"--quantization {model_config.quantization}", + recommend_to_remove=False) + return False + + # No Embedding Models so far. + if model_config.task not in ["generate"]: + _raise_or_fallback(feature_name=f"--task {model_config.task}", + recommend_to_remove=False) + return False + + # No Mamba or Encoder-Decoder so far. + if not model_config.is_v1_compatible: + _raise_or_fallback(feature_name=model_config.architectures, + recommend_to_remove=False) + return False + + # No Concurrent Partial Prefills so far. + if (self.max_num_partial_prefills + != EngineArgs.max_num_partial_prefills + or self.max_long_partial_prefills + != EngineArgs.max_long_partial_prefills): + _raise_or_fallback(feature_name="Concurrent Partial Prefill", + recommend_to_remove=False) + return False + + # No OTLP observability so far. + if (self.otlp_traces_endpoint or self.collect_detailed_traces): + _raise_or_fallback(feature_name="--otlp-traces-endpoint", + recommend_to_remove=False) + return False + + # Only Ngram speculative decoding so far. + is_ngram_enabled = False + is_eagle_enabled = False + if self.speculative_config is not None: + # This is supported but experimental (handled below). + speculative_method = self.speculative_config.get("method") + if speculative_method: + if speculative_method in ("ngram", "[ngram]"): + is_ngram_enabled = True + elif speculative_method == "eagle": + is_eagle_enabled = True + else: + speculative_model = self.speculative_config.get("model") + if speculative_model in ("ngram", "[ngram]"): + is_ngram_enabled = True + if not (is_ngram_enabled or is_eagle_enabled): + # Other speculative decoding methods are not supported yet. + _raise_or_fallback(feature_name="Speculative Decoding", + recommend_to_remove=False) + return False + + # No Disaggregated Prefill so far. + if self.kv_transfer_config != EngineArgs.kv_transfer_config: + _raise_or_fallback(feature_name="--kv-transfer-config", + recommend_to_remove=False) + return False + + # No FlashInfer or XFormers so far. + V1_BACKENDS = [ + "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", + "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" + ] + if (envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): + name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" + _raise_or_fallback(feature_name=name, recommend_to_remove=True) + return False + + # Platforms must decide if they can support v1 for this model + if not current_platform.supports_v1(model_config=model_config): + _raise_or_fallback( + feature_name=f"device type={current_platform.device_type}", + recommend_to_remove=False) + return False + ############################################################# + # Experimental Features - allow users to opt in. + + # Signal Handlers requires running in main thread. + if (threading.current_thread() != threading.main_thread() + and _warn_or_fallback("Engine in background thread")): + return False + + # PP is supported on V1 with Ray distributed executor, + # but off for MP distributed executor for now. + if (self.pipeline_parallel_size > 1 + and self.distributed_executor_backend != "ray"): + name = "Pipeline Parallelism without Ray distributed executor" + _raise_or_fallback(feature_name=name, recommend_to_remove=False) + return False + + # ngram is supported on V1, but off by default for now. 
+ if is_ngram_enabled and _warn_or_fallback("ngram"): + return False + + # Eagle is under development, so we don't support it yet. + if is_eagle_enabled and _warn_or_fallback("Eagle"): + return False + + # Non-CUDA is supported on V1, but off by default for now. + not_cuda = not current_platform.is_cuda() + if not_cuda and _warn_or_fallback( # noqa: SIM103 + current_platform.device_name): + return False + ############################################################# + + return True + + def _set_default_args_v0(self, model_config: ModelConfig) -> None: + """Set Default Arguments for V0 Engine.""" + + max_model_len = model_config.max_model_len + use_long_context = max_model_len > 32768 + if self.enable_chunked_prefill is None: + # Chunked prefill not supported for Multimodal or MLA in V0. + if model_config.is_multimodal_model or model_config.use_mla: + self.enable_chunked_prefill = False + + # Enable chunked prefill by default for long context (> 32K) + # models to avoid OOM errors in initial memory profiling phase. + elif use_long_context: + from vllm.platforms import current_platform + is_gpu = current_platform.is_cuda() + use_sliding_window = (model_config.get_sliding_window() + is not None) + use_spec_decode = self.speculative_config is not None + + if (is_gpu and not use_sliding_window and not use_spec_decode + and not self.enable_lora + and not self.enable_prompt_adapter + and model_config.runner_type != "pooling"): + self.enable_chunked_prefill = True + logger.warning( + "Chunked prefill is enabled by default for models " + "with max_model_len > 32K. Chunked prefill might " + "not work with some features or models. If you " + "encounter any issues, please disable by launching " + "with --enable-chunked-prefill=False.") + + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = False + + if not self.enable_chunked_prefill and use_long_context: + logger.warning( + "The model has a long context length (%s). This may cause" + "OOM during the initial memory profiling phase, or result " + "in low performance due to small KV cache size. Consider " + "setting --max-model-len to a smaller value.", max_model_len) + elif (self.enable_chunked_prefill + and model_config.runner_type == "pooling"): + msg = "Chunked prefill is not supported for pooling models" + raise ValueError(msg) + + # if using prefix caching, we must set a hash algo + if self.enable_prefix_caching: + # Disable prefix caching for multimodal models for VLLM_V0. + if model_config.is_multimodal_model: + logger.warning( + "--enable-prefix-caching is not supported for multimodal " + "models in V0 and has been disabled.") + self.enable_prefix_caching = False + + # VLLM_V0 only supports builtin hash algo for prefix caching. + if self.prefix_caching_hash_algo is None: + self.prefix_caching_hash_algo = "builtin" + elif self.prefix_caching_hash_algo == "sha256": + raise ValueError( + "sha256 is not supported for prefix caching in V0 engine. " + "Please use 'builtin'.") + + # Set max_num_seqs to 256 for VLLM_V0. + if self.max_num_seqs is None: + self.max_num_seqs = 256 + + def _set_default_args_v1(self, usage_context: UsageContext) -> None: + """Set Default Arguments for V1 Engine.""" + + # V1 always uses chunked prefills. + self.enable_chunked_prefill = True + + # V1 enables prefix caching by default. 
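+ # (--enable-prefix-caching defaults to None rather than False, which is
+ # how "not specified by the user" is distinguished from an explicit
+ # --no-enable-prefix-caching here.)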
+ if self.enable_prefix_caching is None: + self.enable_prefix_caching = True + + # if using prefix caching, we must set a hash algo + if self.enable_prefix_caching and self.prefix_caching_hash_algo is None: + self.prefix_caching_hash_algo = "builtin" + + # V1 should use the new scheduler by default. + # Swap it only if this arg is set to the original V0 default + if self.scheduler_cls == EngineArgs.scheduler_cls: + self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" + + # When no user override, set the default values based on the usage + # context. + # Use different default values for different hardware. + + # Try to query the device name on the current platform. If it fails, + # it may be because the platform that imports vLLM is not the same + # as the platform that vLLM is running on (e.g. the case of scaling + # vLLM with Ray) and has no GPUs. In this case we use the default + # values for non-H100/H200 GPUs. + try: + from vllm.platforms import current_platform + device_name = current_platform.get_device_name().lower() + except Exception: + # This is only used to set default_max_num_batched_tokens + device_name = "no-device" + + if "h100" in device_name or "h200" in device_name: + # For H100 and H200, we use larger default values. + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 16384, + UsageContext.OPENAI_API_SERVER: 8192, + } + else: + # TODO(woosuk): Tune the default values for other hardware. + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 8192, + UsageContext.OPENAI_API_SERVER: 2048, + } + + use_context_value = usage_context.value if usage_context else None + if (self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens): + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context] + logger.debug( + "Setting max_num_batched_tokens to %d for %s usage context.", + self.max_num_batched_tokens, use_context_value) + + default_max_num_seqs = 1024 + if self.max_num_seqs is None: + self.max_num_seqs = default_max_num_seqs + + logger.debug("Setting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, use_context_value) + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + disable_log_requests: bool = False + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser, + async_args_only: bool = False) -> FlexibleArgumentParser: + # Initialize plugin to update the parser, for example, The plugin may + # adding a new kind of quantization method to --quantization argument or + # a new device to --device argument. + load_general_plugins() + if not async_args_only: + parser = EngineArgs.add_cli_args(parser) + parser.add_argument('--disable-log-requests', + action='store_true', + help='Disable logging requests.') + from vllm.platforms import current_platform + current_platform.pre_register_and_update(parser) + return parser + + +def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): + if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + raise NotImplementedError( + f"VLLM_USE_V1=1 is not supported with {feature_name}.") + msg = f"{feature_name} is not supported by the V1 Engine. " + msg += "Falling back to V0. " + if recommend_to_remove: + msg += f"We recommend to remove {feature_name} from your config " + msg += "in favor of the V1 Engine." + logger.warning(msg) + + +def _warn_or_fallback(feature_name: str) -> bool: + if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + logger.warning( + "Detected VLLM_USE_V1=1 with %s. 
Usage should " + "be considered experimental. Please report any " + "issues on Github.", feature_name) + should_exit = False + else: + logger.info( + "%s is experimental on VLLM_USE_V1=1. " + "Falling back to V0 Engine.", feature_name) + should_exit = True + return should_exit + + +# These functions are used by sphinx to build the documentation +def _engine_args_parser(): + return EngineArgs.add_cli_args(FlexibleArgumentParser()) + + +def _async_engine_args_parser(): + return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(), + async_args_only=True) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py new file mode 100644 index 0000000..05d6376 --- /dev/null +++ b/vllm/engine/async_llm_engine.py @@ -0,0 +1,1266 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import copy +import time +import weakref +from functools import partial +from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, + List, Mapping, Optional, Set, Tuple, Type, Union, overload) +from weakref import ReferenceType + +from typing_extensions import deprecated + +import vllm.envs as envs +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.core.scheduler import SchedulerOutputs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_timeout import asyncio_timeout +from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState +from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.protocol import EngineClient +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import ExecuteModelRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Device, deprecate_kwargs, weak_bind + +logger = init_logger(__name__) +ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _log_task_completion(task: asyncio.Task, + error_callback: Callable[[Exception], None]) -> None: + """This function is only intended for the `engine.run_engine_loop()` task. + + In particular, that task runs a `while True` loop that can only exit if + there is an exception. + """ + + exception = None + try: + return_value = task.result() + raise AssertionError( + f"The engine background task should never finish without an " + f"exception. {return_value}") + except asyncio.exceptions.CancelledError: + # We assume that if the task is cancelled, we are gracefully shutting + # down. This should only happen on program exit. + logger.info("Engine is gracefully shutting down.") + except Exception as e: + exception = e + logger.error("Engine background task failed", exc_info=e) + error_callback(exception) + raise AsyncEngineDeadError( + "Task finished unexpectedly. This should never happen! " + "Please open an issue on GitHub. 
See stack trace above for the " + "actual cause.") from e + + +STOP_ITERATION = Exception() # Sentinel + + +class AsyncStream: + """A stream of RequestOutputs or PoolingRequestOutputs for a request + that can be iterated over asynchronously via an async generator.""" + + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: + self.request_id = request_id + self._cancel = cancel + self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + + def put(self, item: Union[RequestOutput, PoolingRequestOutput, + Exception]) -> None: + if not self._finished: + self._queue.put_nowait(item) + + def finish( + self, + exception: Optional[Union[BaseException, Type[BaseException]]] = None, + ) -> None: + if not self._finished: + self._finished = True + self._queue.put_nowait( + exception if self._is_raisable(exception) else STOP_ITERATION) + + @property + def finished(self) -> bool: + return self._finished + + async def generator( + self + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + try: + while True: + result = await self._queue.get() + if self._is_raisable(result): + if result == STOP_ITERATION: + return + raise result + yield result + except GeneratorExit: + self._cancel(self.request_id) + raise asyncio.CancelledError from None + + @staticmethod + def _is_raisable(value: Any): + return isinstance(value, BaseException) or \ + (isinstance(value, type) and \ + issubclass(value, BaseException)) + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, + dict]] = asyncio.Queue() + self.new_requests_event = asyncio.Event() + + def __contains__(self, item): + return item in self._request_streams + + def __len__(self) -> int: + return len(self._request_streams) + + def propagate_exception(self, + exc: Exception, + request_id: Optional[str] = None) -> None: + """Propagate an exception to request streams + (all if request_id is None).""" + if request_id is not None: + self.abort_request(request_id, exception=exc) + else: + # NB: tuple() used here because self.abort_request pops the stream + # out of self._request_streams, so we can't iterate on it directly + for rid in tuple(self._request_streams.keys()): + self.abort_request(rid, exception=exc) + + def process_request_output(self, + request_output: Union[RequestOutput, + PoolingRequestOutput], + *, + verbose: bool = False) -> None: + """Process a request output from the engine.""" + request_id = request_output.request_id + finished = request_output.finished + + if finished: + stream = self._request_streams.pop(request_id, None) + else: + stream = self._request_streams.get(request_id) + # Guard against a KeyError which can occur if the request was aborted + # while the output was generated + if stream is not None: + stream.put(request_output) + if finished: + stream.finish() + + if verbose and finished: + logger.info("Finished request %s.", request_id) + + def process_exception(self, + request_id: str, + exception: BaseException, + *, + verbose: bool = False) -> None: + """Propagate an exception from the engine.""" + if verbose: + logger.info("Finished request %s.", request_id) + self.abort_request(request_id, exception=exception) + + def add_request(self, + request_id: str, + *, + verbose: bool = False, + **engine_add_request_kwargs) -> AsyncStream: + """Add a request to be 
sent to the engine on the next background + loop iteration.""" + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + abort_request = partial(self.abort_request, verbose=verbose) + stream = AsyncStream(request_id, abort_request) + self._new_requests.put_nowait((stream, { + "request_id": request_id, + **engine_add_request_kwargs + })) + + self.new_requests_event.set() + + if verbose: + logger.info("Added request %s.", request_id) + + return stream + + def abort_request(self, + request_id: str, + *, + exception: Optional[Union[BaseException, + Type[BaseException]]] = None, + verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info("Aborted request %s.", request_id) + + self._aborted_requests.put_nowait(request_id) + + stream = self._request_streams.pop(request_id, None) + if stream is not None: + stream.finish(exception=exception) + + def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[Dict] = [] + finished_requests: Set[str] = set() + + while not self._aborted_requests.empty(): + request_id = self._aborted_requests.get_nowait() + finished_requests.add(request_id) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + request_id = stream.request_id + if request_id in finished_requests: + # The request has already been aborted. + stream.finish(asyncio.CancelledError) + finished_requests.discard(request_id) + else: + self._request_streams[request_id] = stream + new_requests.append(new_request) + + return new_requests, finished_requests + + async def wait_for_new_requests(self): + if not self.has_new_requests(): + await self.new_requests_event.wait() + self.new_requests_event.clear() + + def has_new_requests(self): + return not self._new_requests.empty() + + +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def step_async( + self, virtual_engine: int + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. 
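+ # ("Remaining steps" only exist with multi-step scheduling, i.e.
+ # --num-scheduler-steps > 1; in the default single-step case the
+ # scheduler is invoked on every iteration.)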
+ if not self._has_remaining_steps(seq_group_metadata_list): + + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + if not scheduler_outputs.is_empty(): + # this will cause mamba_cache/minimax_cache failed + # to release finished_requests_ids of the last steps + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + else: + finished_requests_ids = list() + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + virtual_engine=virtual_engine, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + # Execute the model. + outputs = await self.model_executor.execute_model_async( + execute_model_req) + + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # Clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. 
+ is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len( + outputs + ) == 1, "Async postprocessor expects only a single output set" + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + return ctx.request_outputs + + async def stop_remote_worker_execution_loop_async(self) -> None: + """Stop the remote worker execution loop.""" + await self.model_executor.stop_remote_worker_execution_loop_async() + + async def get_tokenizer_async(self, + lora_request: Optional[LoRARequest] = None + ) -> AnyTokenizer: + return await ( + self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) + + @overload + @deprecated("'inputs' will be renamed to 'prompt") + async def add_request_async( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @overload + async def add_request_async( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... 
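+ # (The two @overload stubs above only describe the accepted call
+ # signatures for type checkers; the decorated method below is the
+ # actual implementation.)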
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request_async( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + """Async version of :meth:`add_request`.""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: + arrival_time = time.time() + + if self.tokenizer is not None: + tokenizer = await self.get_tokenizer_async(lora_request) + self._validate_token_prompt(prompt, tokenizer=tokenizer) + + preprocessed_inputs = await self.input_preprocessor.preprocess_async( + prompt, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + # Guided decoding has an async implementation for building logits + # processors in a separate threadpool. + # We want to invoke that here instead of using the blocking + # implementation in the LLMEngine + params = await build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer_async(lora_request), + default_guided_backend=self.decoding_config. + guided_decoding_backend, + reasoning_backend=self.decoding_config.reasoning_backend, + model_config=self.model_config) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + async def check_health_async(self) -> None: + if self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() + + +async def build_guided_decoding_logits_processor_async( + sampling_params: SamplingParams, tokenizer: AnyTokenizer, + default_guided_backend: str, reasoning_backend: Optional[str], + model_config: ModelConfig) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Modifies sampling params in-place and returns + the modified sampling params.""" + if sampling_params.guided_decoding is None: + return sampling_params + + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + + logger.debug( + "Building guided decoding logits processor. 
" + "guided_decoding: %s%s", guided_decoding, + f", reasoning_backend: {reasoning_backend}" + if reasoning_backend is not None else "") + + guided_decoding.backend = guided_decoding.backend or default_guided_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, + tokenizer=tokenizer, + reasoning_backend=reasoning_backend, + model_config=model_config) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params + + +class AsyncLLMEngine(EngineClient): + """An asynchronous wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. + + Args: + log_requests: Whether to log the requests. + start_engine_loop: If True, the background task to run the engine + will be automatically started in the generate call. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. + """ + + _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine + + def __init__(self, + *args, + log_requests: bool = True, + start_engine_loop: bool = True, + **kwargs) -> None: + if envs.VLLM_USE_V1: + raise ValueError( + "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. " + "This should not happen. As a workaround, try using " + "AsyncLLMEngine.from_vllm_config(...) or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + self.log_requests = log_requests + self.engine = self._engine_class(*args, **kwargs) + + # This ensures quick processing of request outputs + # so the append to asyncio queues is not delayed, + # especially for multi-step. 
+ self.use_process_request_outputs_callback = ( + self.engine.model_config.use_async_output_proc) + + if self.use_process_request_outputs_callback: + self.engine.process_request_outputs_callback = \ + weak_bind(self.process_request_outputs) + + self.background_loop: Optional[asyncio.Future] = None + # We need to keep a reference to unshielded + # task as well to prevent it from being garbage + # collected + self._background_loop_unshielded: Optional[asyncio.Task] = None + self.start_engine_loop = start_engine_loop + self._errored_with: Optional[BaseException] = None + + # Lazy initialized fields + self._request_tracker: RequestTracker + + def __del__(self): + if rt := getattr(self, "request_tracker", None): + # Wake up engine loop so that it will exit cleanly + rt.new_requests_event.set() + + @classmethod + def _get_executor_cls(cls, + engine_config: VllmConfig) -> Type[ExecutorBase]: + return LLMEngine._get_executor_cls(engine_config) + + @classmethod + def from_vllm_config( + cls, + vllm_config: VllmConfig, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + disable_log_requests: bool = False, + disable_log_stats: bool = False, + ) -> "AsyncLLMEngine": + """Create an AsyncLLMEngine from the EngineArgs.""" + + return cls( + vllm_config=vllm_config, + executor_class=cls._get_executor_cls(vllm_config), + start_engine_loop=start_engine_loop, + log_requests=not disable_log_requests, + log_stats=not disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + @classmethod + def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + + vllm_config = engine_args.create_engine_config(usage_context) + + async_engine_cls = cls + if envs.VLLM_USE_V1: + from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine + async_engine_cls = V1AsyncLLMEngine + + return async_engine_cls.from_vllm_config( + vllm_config=vllm_config, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + disable_log_stats=engine_args.disable_log_stats, + disable_log_requests=engine_args.disable_log_requests, + ) + + @property + def is_running(self) -> bool: + return (self.background_loop is not None + and self._background_loop_unshielded is not None + and not self._background_loop_unshielded.done()) + + @property + def is_stopped(self) -> bool: + return self.errored or (self.background_loop is not None and + self._background_loop_unshielded is not None + and self._background_loop_unshielded.done()) + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + def set_errored(self, exc: Exception) -> None: + self._errored_with = exc + + def _error_callback(self, exc: Exception) -> None: + self.set_errored(exc) + self._request_tracker.propagate_exception(exc) + + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.engine.input_preprocessor + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return await self.engine.get_tokenizer_async(lora_request) + + def start_background_loop(self) -> None: + """Start the background loop.""" + if self.errored: + raise AsyncEngineDeadError( + "Background loop has errored already.") from self._errored_with + if self.is_running: + raise RuntimeError("Background loop is already running.") + # Initialize the RequestTracker here so it uses the right event loop. + self._request_tracker = RequestTracker() + + self._background_loop_unshielded = asyncio.get_event_loop( + ).create_task(self.run_engine_loop(weakref.ref(self))) + self._background_loop_unshielded.add_done_callback( + partial(_log_task_completion, error_callback=self._error_callback)) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def shutdown_background_loop(self) -> None: + """ + Shut down the background loop. + + This method needs to be called during cleanup to remove + references to `self` and properly GC the resources held + by the async LLM engine (e.g., the executors as well as + their resources). + """ + if self._background_loop_unshielded is not None: + self._background_loop_unshielded.cancel() + self._background_loop_unshielded = None + self.background_loop = None + + async def engine_step(self, virtual_engine: int) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" + + new_requests, aborted_requests = ( + self._request_tracker.get_new_and_aborted_requests()) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + try: + await self.engine.add_request_async(**new_request) + except ValueError as e: + # TODO: use a vLLM specific error for failed validation + self._request_tracker.process_exception( + new_request["request_id"], + e, + verbose=self.log_requests, + ) + + if aborted_requests: + await self._engine_abort(aborted_requests) + + request_outputs = await self.engine.step_async(virtual_engine) + + # Put the outputs into the corresponding streams. + # If used as a callback, then already invoked inside + # LLMEngine's _process_model_outputs + if not self.use_process_request_outputs_callback: + all_finished = self.process_request_outputs(request_outputs) + else: + # For callback case, we only need to detect when all + # requests are finished + all_finished = all(request_output.finished + for request_output in request_outputs) + + return not all_finished + + def process_request_outputs(self, request_outputs) -> bool: + # Put the outputs into the corresponding streams. 
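+        # Returns True only when every request output in this batch has
+        # finished, so the engine loop can decide whether to keep stepping.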
+        all_finished = True
+        for request_output in request_outputs:
+            self._request_tracker.process_request_output(
+                request_output, verbose=self.log_requests)
+            all_finished = all_finished and request_output.finished
+
+        return all_finished
+
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        self.engine.abort_request(request_ids)
+
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional[AsyncLLMEngine] = engine_ref()
+        if not engine:
+            return
+
+        # pipeline_parallel_size = \
+        #     engine.engine.parallel_config.pipeline_parallel_size
+        # has_requests_in_progress = [False] * pipeline_parallel_size
+        num_virtual_engine = \
+            engine.engine.parallel_config.num_virtual_engine
+        has_requests_in_progress = [False] * num_virtual_engine
+        while True:
+            if not any(has_requests_in_progress):
+                logger.debug("Waiting for new requests...")
+                # Stop the execute model loop in parallel workers until there
+                # are more requests to process. This avoids waiting
+                # indefinitely in torch.distributed ops which may otherwise
+                # timeout, and unblocks the RPC thread in the workers so that
+                # they can process any other queued control plane messages,
+                # such as add/remove lora adapters.
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
+                logger.debug("Got new requests!")
+                # NOTE: This deviates from upstream vLLM (and may be
+                # unnecessary in server mode, where each vLLM server runs in
+                # its own process). Upstream creates one engine_step task per
+                # pipeline stage, i.e. "pipeline_parallel_size" tasks. With
+                # pp=4, requests_in_progress would hold 4 tasks, and
+                # asyncio.wait(tasks, FIRST_COMPLETED) returns immediately for
+                # virtual engines that have no input to process, so the loop
+                # ends up executing a large number of asyncio.sleep(0) calls.
+                # Instead, we only create a task for a virtual engine that
+                # actually has inputs; tasks for the other virtual engines are
+                # created after the first task completes (see below).
+                # requests_in_progress = [
+                #     asyncio.create_task(engine.engine_step(ve))
+                #     for ve in range(pipeline_parallel_size)
+                # ]
+                # has_requests_in_progress = [True] * pipeline_parallel_size
+                requests_in_progress = {0: asyncio.create_task(engine.engine_step(0))}
+                has_requests_in_progress[0] = True
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
+                    done, _ = await asyncio.wait(
+                        requests_in_progress.values(),
+                        return_when=asyncio.FIRST_COMPLETED)
+                    for _ in range(len(requests_in_progress)):
+                        await asyncio.sleep(0)
+                    add_virtual_engine = set(range(num_virtual_engine))
+                    for task in done:
+                        result = task.result()
+                        # virtual_engine = requests_in_progress.index(task)
+                        virtual_engine = [k for k, v in requests_in_progress.items() if v == task][0]
+                        has_unfinished_requests = (
+                            engine.engine.
+ has_unfinished_requests_for_virtual_engine( + virtual_engine)) + if result or has_unfinished_requests: + requests_in_progress[virtual_engine] = ( + asyncio.create_task( + engine.engine_step(virtual_engine))) + has_requests_in_progress[virtual_engine] = True + else: + requests_in_progress.pop(virtual_engine) + has_requests_in_progress[virtual_engine] = False + # check other engine + for ve in list(add_virtual_engine - requests_in_progress.keys()): + has_unfinished_requests = ( + engine.engine.has_unfinished_requests_for_virtual_engine( + ve)) + if has_unfinished_requests: + requests_in_progress[ve] = asyncio.create_task(engine.engine_step(ve)) + has_requests_in_progress[ve] = True + except asyncio.TimeoutError as exc: + logger.error( + "Engine iteration timed out. This should never happen!") + engine.set_errored(exc) + raise + await asyncio.sleep(0) + + # This method does not need to be async, but kept that way + # for backwards compatibility. + @overload + @deprecated("'inputs' will be renamed to 'prompt") + def add_request( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, PoolingRequestOutput], None]]: + ... + + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, PoolingRequestOutput], None]]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + + if not self.is_running: + if self.start_engine_loop: + self.start_background_loop() + else: + raise AsyncEngineDeadError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + if (priority != 0 + and not self.engine.scheduler_config.policy == "priority"): + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + stream = self._request_tracker.add_request( + request_id, + verbose=self.log_requests, + prompt=prompt, + params=params, + arrival_time=arrival_time or time.time(), + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + + return stream.generator() + + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: The priority of the request. + Only applicable with priority scheduling. + + Yields: + The output `RequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> # note that engine_args here is AsyncEngineArgs instance + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "prompt": "What is LLM?", + >>> "stream": False, # assume the non-streaming case + >>> "temperature": 0.0, + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.generate( + >>> example_input["prompt"], + >>> SamplingParams(temperature=example_input["temperature"]), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... 
+ """ + try: + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ): + yield LLMEngine.validate_output(output, RequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. + + Yields: + The output `PoolingRequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> # note that engine_args here is AsyncEngineArgs instance + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "input": "What is LLM?", + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.encode( + >>> example_input["input"], + >>> PoolingParams(), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + try: + async for output in await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ): + yield LLMEngine.validate_output(output, PoolingRequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + self._request_tracker.abort_request(request_id, + exception=asyncio.CancelledError, + verbose=self.log_requests) + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + return self.engine.get_model_config() + + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + return self.engine.get_parallel_config() + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + return self.engine.get_decoding_config() + + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + return self.engine.get_lora_config() + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: + self.engine.do_log_stats() + + async def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + t = time.perf_counter() + logger.debug("Starting health check...") + if self.is_stopped: + raise AsyncEngineDeadError("Background loop is stopped.") + + await self.engine.check_health_async() + logger.debug("Health check took %fs", time.perf_counter() - t) + + async def is_tracing_enabled(self) -> bool: + return self.engine.is_tracing_enabled() + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + self.engine.add_logger(logger_name=logger_name, logger=logger) + + def remove_logger(self, logger_name: str) -> None: + self.engine.remove_logger(logger_name=logger_name) + + async def start_profile(self) -> None: + self.engine.start_profile() + + async def stop_profile(self) -> None: + self.engine.stop_profile() + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + self.engine.reset_prefix_cache(device) + + async def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) + + async def is_sleeping(self) -> bool: + return self.engine.is_sleeping() + + async def add_lora(self, lora_request: LoRARequest) -> None: + self.engine.add_lora(lora_request) + + +# TODO(v1): Remove this class proxy when V1 goes default. 
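+# When VLLM_USE_V1 is explicitly enabled, alias AsyncLLMEngine to the V1
+# AsyncLLM implementation so existing imports pick up the new engine.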
+if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + from vllm.v1.engine.async_llm import AsyncLLM + + AsyncLLMEngine = AsyncLLM # type: ignore diff --git a/vllm/engine/async_timeout.py b/vllm/engine/async_timeout.py new file mode 100644 index 0000000..aa54c06 --- /dev/null +++ b/vllm/engine/async_timeout.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Workaround for https://github.com/python/cpython/issues/86296 +# +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License (Apache-2.0) + +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Any, Optional, Type + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout +else: + + def asyncio_timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + deadline = loop.time() + delay if delay is not None else None + return Timeout(deadline, loop) + + class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. 
+ + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__(self, deadline: Optional[float], + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout()", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + The delay can be negative. + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError( + "cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + deadline argument points on the time in the same clock system + as loop.time(). + If new deadline is in the past the timeout is raised immediately. + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. 
+ """ + if self._state == _State.EXIT: + raise RuntimeError( + "cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon( + self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at( + deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and \ + self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None: + if task: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py new file mode 100644 index 0000000..eecb56a --- /dev/null +++ b/vllm/engine/llm_engine.py @@ -0,0 +1,2145 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +import time +from collections import Counter as collectionsCounter +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Mapping, NamedTuple, Optional) +from typing import Sequence as GenericSequence +from typing import Set, Type, Union, cast, overload + +import torch +from typing_extensions import TypeVar, deprecated + +import vllm.envs as envs +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) +from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics_types import StatLoggerBase, Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group +from vllm.entrypoints.openai.logits_processors import ( + get_logits_processors as get_openai_logits_processors) +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType) +from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.logits_process import get_bad_words_logits_processors +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_local_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from 
vllm.outputs import (PoolingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, + PoolingSequenceGroupOutput, Sequence, SequenceGroup, + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupOutput, SequenceStatus) +from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, + init_tracer) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import (Counter, Device, deprecate_kwargs, + resolve_obj_by_qualname, weak_bind) +from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.model_runner_base import InputProcessingError + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) +_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) +_R = TypeVar("_R", default=Any) + + +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + allow_async_output_proc: bool = False + last_output: Optional[SamplerOutput] = None + + +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. + is_first_step_output: Optional[bool] + skip: List[int] + + +class SchedulerContext: + + def __init__(self, multi_step_stream_outputs: bool = False): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + PoolingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool, + is_first_step_output: Optional[bool]): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + is_first_step_output=is_first_step_output, + skip=[])) + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. 
+ + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine-args`) + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return cast(_O, output) + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[BaseTokenizerGroup] + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + if envs.VLLM_USE_V1: + raise ValueError( + "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. " + "This should not happen. As a workaround, try using " + "LLMEngine.from_vllm_config(...) 
or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config # noqa + self.load_config = vllm_config.load_config + self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + + logger.info( + "Initializing a V0 LLM engine (v%s) with config: %s, " + "use_cached_outputs=%s, ", + VLLM_VERSION, + vllm_config, + use_cached_outputs, + ) + + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, ("tokenizer_group cannot be None, " + "make sure skip_tokenizer_init is False") + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) + + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer, + mm_registry) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor( + self.model_config) + + self.model_executor = executor_class(vllm_config=vllm_config, ) + + if self.model_config.runner_type != "pooling": + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + usage_message.report_usage( + get_architecture_class_name(self.model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(self.model_config.dtype), + "tensor_parallel_size": + self.parallel_config.tensor_parallel_size, + "block_size": + self.cache_config.block_size, + "gpu_memory_utilization": + self.cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + self.model_config.quantization, + "kv_cache_dtype": + str(self.cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(self.lora_config), + "enable_prompt_adapter": + bool(self.prompt_adapter_config), + "enable_prefix_caching": + self.cache_config.enable_prefix_caching, + "enforce_eager": + self.model_config.enforce_eager, + "disable_custom_all_reduce": + self.parallel_config.disable_custom_all_reduce, + }) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. 
+ multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if self.model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str): + Scheduler = resolve_obj_by_qualname( + self.vllm_config.scheduler_config.scheduler_cls) + else: + Scheduler = self.vllm_config.scheduler_config.scheduler_cls + self.scheduler = [ + Scheduler( + self.scheduler_config, self.cache_config, self.lora_config, + self.parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] + if self.model_config.use_async_output_proc else None) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + vllm_config=vllm_config), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict( + model_name=self.model_config.served_model_name), + vllm_config=vllm_config), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + )) + + self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + + # Flag to set when an input fails to process and the engine should run + # the next step without re-scheduling. + self._skip_scheduling_next_step = False + + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + start = time.time() + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + # TODO fix this... 
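+        # Workaround for GGUF-quantized DeepSeek models: they are restricted
+        # to a single GPU here, and the profiled GPU/CPU KV-cache block counts
+        # are scaled down to 60%.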
+        if (self.model_config.hf_config.model_type
+                in ["deepseek_v3", "deepseek_v2", "deepseek2", "deepseek3"]
+                and self.model_config.quantization == "gguf"):
+            from vllm.distributed import get_tensor_model_parallel_world_size
+            assert get_tensor_model_parallel_world_size() == 1, (
+                "Combining CPU inference and GGUF quantization for DeepSeek "
+                "models with multiple GPUs is not supported for now!")
+            num_gpu_blocks = int(num_gpu_blocks * 0.6)
+            num_cpu_blocks = int(num_cpu_blocks * 0.6)
+
+        if self.cache_config.num_gpu_blocks_override is not None:
+            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
+            logger.info(
+                "Overriding num_gpu_blocks=%d with "
+                "num_gpu_blocks_override=%d", num_gpu_blocks,
+                num_gpu_blocks_override)
+            num_gpu_blocks = num_gpu_blocks_override
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+        elapsed = time.time() - start
+        logger.info(("init engine (profile, create kv cache, "
+                     "warmup model) took %.2f seconds"), elapsed)
+
+    @classmethod
+    def _get_executor_cls(cls,
+                          engine_config: VllmConfig) -> Type[ExecutorBase]:
+        # distributed_executor_backend must be set in VllmConfig.__post_init__
+        distributed_executor_backend = (
+            engine_config.parallel_config.distributed_executor_backend)
+        # Initialize the cluster and specify the executor class.
+        if isinstance(distributed_executor_backend, type):
+            if not issubclass(distributed_executor_backend, ExecutorBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorBase. Got {distributed_executor_backend}.")
+            executor_class = distributed_executor_backend
+        elif distributed_executor_backend == "ray":
+            from vllm.executor.ray_distributed_executor import (
+                RayDistributedExecutor)
+            executor_class = RayDistributedExecutor
+        elif distributed_executor_backend == "mp":
+            from vllm.executor.mp_distributed_executor import (
+                MultiprocessingDistributedExecutor)
+            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
+                "multiprocessing distributed executor backend does not "
+                "support VLLM_USE_RAY_SPMD_WORKER=1")
+            executor_class = MultiprocessingDistributedExecutor
+        elif distributed_executor_backend == "uni":
+            # JAX-style, single-process, multi-device executor.
+            from vllm.executor.uniproc_executor import UniProcExecutor
+            executor_class = UniProcExecutor
+        elif distributed_executor_backend == "external_launcher":
+            # executor with external launcher
+            from vllm.executor.uniproc_executor import (  # noqa
+                ExecutorWithExternalLauncher)
+            executor_class = ExecutorWithExternalLauncher
+        else:
+            raise ValueError("unrecognized distributed_executor_backend: "
+                             f"{distributed_executor_backend}")
+        return executor_class
+
+    @classmethod
+    def from_vllm_config(
+        cls,
+        vllm_config: VllmConfig,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+        disable_log_stats: bool = False,
+    ) -> "LLMEngine":
+        return cls(
+            vllm_config=vllm_config,
+            executor_class=cls._get_executor_cls(vllm_config),
+            log_stats=(not disable_log_stats),
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+        )
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: EngineArgs,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+    ) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
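+        # When VLLM_USE_V1 is set, construction is delegated to the V1
+        # LLMEngine; otherwise this V0 engine is built directly.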
+ vllm_config = engine_args.create_engine_config(usage_context) + + engine_cls = cls + if envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + engine_cls = V1LLMEngine + + return engine_cls.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + stat_loggers=stat_loggers, + disable_log_stats=engine_args.disable_log_stats, + ) + + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + + def get_tokenizer_group( + self, + group_type: Type[_G] = BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group + + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.get_tokenizer_group().get_lora_tokenizer(lora_request) + + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + lora_config=self.lora_config) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def _add_processed_request( + self, + request_id: str, + processed_inputs: ProcessorInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> Optional[SequenceGroup]: + """Add a processed request to the engine's request pool. + return the created sequence group. + """ + if isinstance(params, SamplingParams) and params.n > 1: + ParallelSampleSequenceGroup.add_request( + request_id, + self, + params, + processed_inputs=processed_inputs, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + return None + + self._validate_model_inputs(processed_inputs, lora_request) + # Create the sequences. 
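+        # Encoder-decoder inputs are split into an encoder sequence and a
+        # decoder sequence; decoder-only models produce a single sequence.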
+ block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + + seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, + lora_request, prompt_adapter_request) + + encoder_seq = (None if encoder_inputs is None else Sequence( + seq_id, encoder_inputs, block_size, eos_token_id, lora_request, + prompt_adapter_request)) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler with least unfinished seqs. + costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) + + return seq_group + + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() + + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @overload + @deprecated("'inputs' will be renamed to 'prompt") + def add_request( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. 
+ :class:`~vllm.PoolingParams` for pooling. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + lora_request: The LoRA request to add. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: The prompt adapter request to add. + priority: The priority of the request. + Only applicable with priority scheduling. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. + - Create `n` number of :class:`~vllm.Sequence` objects. + - Create a :class:`~vllm.SequenceGroup` object + from the list of :class:`~vllm.Sequence`. + - Add the :class:`~vllm.SequenceGroup` object to the scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... + """ + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + if isinstance(params, SamplingParams) \ + and (params.guided_decoding or params.logits_processors) \ + and self.scheduler_config.num_scheduler_steps > 1: + raise ValueError( + "Guided decoding and logits processors are not supported " + "in multi-step decoding") + + if arrival_time is None: + arrival_time = time.time() + + if self.tokenizer is not None: + self._validate_token_prompt( + prompt, + tokenizer=self.get_tokenizer(lora_request=lora_request)) + + preprocessed_inputs = self.input_preprocessor.preprocess( + prompt, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + def _validate_token_prompt(self, prompt: PromptType, + tokenizer: AnyTokenizer): + # Guard against out-of-vocab tokens. + # For some tokenizers, tokenizer.decode will happily return empty text + # for token ids that are out of vocab, and we don't detect token ids + # that are greater than the max token id before running the model. + # However, these token ids will later crash a cuda kernel at runtime + # with an index out of bounds error. This will crash the entire engine. + # This needs to happen before multimodal input pre-processing, which + # may add dummy tokens that aren't part of the tokenizer's + # vocabulary. 
+ if is_token_prompt(prompt): + prompt_ids = prompt["prompt_token_ids"] + if len(prompt_ids) == 0: + # Empty prompt check is handled later + return + max_input_id = max(prompt_ids) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + "Token id {} is out of vocabulary".format(max_input_id)) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + + sampling_params = self._build_logits_processors( + sampling_params, lora_request) + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + + sampling_params.update_from_generation_config( + self.generation_config_fields, seq.eos_token_id) + + # Create the sequence group. + draft_size = 1 + if self.vllm_config.speculative_config is not None: + draft_size = \ + self.vllm_config.speculative_config.num_speculative_tokens + 1 + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority, + draft_size=draft_size) + + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + return seq_group + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + + Details: + - Refer to the + :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` + from class :class:`~vllm.core.scheduler.Scheduler`. 
+
+        Example:
+            >>> # initialize engine and add a request with request_id
+            >>> request_id = str(0)
+            >>> # abort the request
+            >>> engine.abort_request(request_id)
+        """
+        for scheduler in self.scheduler:
+            scheduler.abort_seq_group(
+                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
+
+    def get_model_config(self) -> ModelConfig:
+        """Gets the model configuration."""
+        return self.model_config
+
+    def get_parallel_config(self) -> ParallelConfig:
+        """Gets the parallel configuration."""
+        return self.parallel_config
+
+    def get_decoding_config(self) -> DecodingConfig:
+        """Gets the decoding configuration."""
+        return self.decoding_config
+
+    def get_scheduler_config(self) -> SchedulerConfig:
+        """Gets the scheduler configuration."""
+        return self.scheduler_config
+
+    def get_lora_config(self) -> LoRAConfig:
+        """Gets the LoRA configuration."""
+        return self.lora_config
+
+    def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
+        return sum(scheduler.get_num_unfinished_seq_groups()
+                   for scheduler in self.scheduler)
+
+    def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
+        return any(scheduler.has_unfinished_seqs()
+                   for scheduler in self.scheduler)
+
+    def has_unfinished_requests_for_virtual_engine(
+            self, virtual_engine: int) -> bool:
+        """
+        Returns True if there are unfinished requests for the virtual engine.
+        """
+        return self.scheduler[virtual_engine].has_unfinished_seqs()
+
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for all devices."""
+
+        success = True
+        for scheduler in self.scheduler:
+            success = success and scheduler.reset_prefix_cache(device)
+        return success
+
+    @staticmethod
+    def _process_sequence_group_outputs(
+        seq_group: SequenceGroup,
+        outputs: List[PoolingSequenceGroupOutput],
+    ) -> None:
+        seq_group.pooled_data = outputs[0].data
+
+        for seq in seq_group.get_seqs():
+            seq.status = SequenceStatus.FINISHED_STOPPED
+
+        return
+
+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first-step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_token updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case prompt-sequences
+            # are actually single-stepped. Always update in this case.
+ assert seq_group.state.num_steps == 1 + do_update = True + + if do_update: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) + + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. + + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed + """ + + now = time.time() + + if len(ctx.output_queue) == 0: + return None + + # Get pending async postprocessor + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() + + # Sanity check + assert len(seq_group_metadata_list) == len( + scheduler_outputs.scheduled_seq_groups) + + has_multiple_outputs: bool = len(outputs) > 1 + outputs_by_sequence_group: List[List[SequenceGroupOutput]] + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. + if self.scheduler_config.is_multi_step: + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, len(seq_group_metadata_list)) + elif self.speculative_config: + # Decodes are multi-steps while prefills are not, outputting at + # most 1 token. Separate them so that we can trigger chunk + # processing without having to pad or copy over prompts K times + # to match decodes structure (costly with prompt_logprobs). + num_prefills = sum(sg.is_prompt + for sg in seq_group_metadata_list) + prefills, decodes = outputs[:num_prefills], outputs[ + num_prefills:] + outputs_by_sequence_group = create_output_by_sequence_group( + decodes, + num_seq_groups=len(seq_group_metadata_list) - num_prefills) + outputs_by_sequence_group = [p.outputs for p in prefills + ] + outputs_by_sequence_group + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. 
+ is_first_step_output = None + else: + outputs_by_sequence_group = outputs + + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + + finished_before: List[int] = [] + finished_now: List[int] = [] + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group: SequenceGroup = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + finished_before.append(i) + continue + + output: List[SequenceGroupOutput] + if has_multiple_outputs: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] + + if not is_async: + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_meta, is_first_step_output) + else: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size or 0) + + if outputs: + for o in outputs: + if (isinstance(o, SamplerOutput) + and seq_group.metrics is not None): + if seq_group.metrics.model_forward_time is not None: + seq_group.metrics.model_forward_time += ( + o.model_forward_time or 0) + else: + seq_group.metrics.model_forward_time = ( + o.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + seq_group.metrics.model_execute_time += ( + o.model_execute_time or 0) + else: + seq_group.metrics.model_execute_time = ( + o.model_execute_time) + + if self.model_config.runner_type == "pooling": + self._process_sequence_group_outputs(seq_group, output) + else: + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs( + seq_group, output, is_async) + + if seq_group.is_finished(): + finished_now.append(i) + + # Generate outputs for the requests that finished this iteration + for i in finished_now: + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + if not seq_group.is_prefill(): + seq_group.set_last_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Free currently finished requests + if finished_now: + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() + + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if (finished_now + and 
self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Create the outputs + for i in indices: + if i in skip or i in finished_before or i in finished_now: + continue # Avoids double processing + + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + if not seq_group.is_prefill(): + seq_group.set_last_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs, + ) + if request_output: + ctx.request_outputs.append(request_output) + + # Immediately process request outputs here (if callback is given) + if (ctx.request_outputs + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + + # For async case, we need to record the stats here. + # For non-async case, the stats are done in the + # LLMEngine/AsyncLLMEngine directly + if is_async: + # Log stats. + self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) + + # Tracing + self.do_tracing(scheduler_outputs, finished_before) + + return None + + def _advance_to_next_step( + self, output: SamplerOutput, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. 
+        """
+        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
+            zip(seq_group_metadata_list, output, scheduled_seq_groups):
+            seq_group = scheduled_seq_group.seq_group
+
+            if seq_group.is_finished():
+                continue
+
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
+            else:
+                token_chunk_size = (seq_group_metadata.token_chunk_size
+                                    if seq_group_metadata.token_chunk_size
+                                    is not None else 0)
+                seq_group.update_num_computed_tokens(token_chunk_size)
+
+            if seq_group_metadata.do_sample:
+                assert len(sequence_group_outputs.samples) == 1, (
+                    "Async output processor expects a single sample"
+                    " (i.e. sampling_params.n == 1)")
+                sample = sequence_group_outputs.samples[0]
+
+                assert len(seq_group.seqs) == 1
+                seq = seq_group.seqs[0]
+
+                if self.scheduler_config.is_multi_step:
+                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
+                    ) == 0
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    if not is_prefill_append:
+                        seq_group.update_num_computed_tokens(1)
+                else:
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+
+    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        .. figure:: https://i.imgur.com/sv2HssD.png
+            :alt: Overview of the step function
+            :align: center
+
+            Overview of the step function.
+
+        Details:
+            - Step 1: Schedules the sequences to be executed in the next
+              iteration and the token blocks to be swapped in/out/copy.
+
+                - Depending on the scheduling policy,
+                  sequences may be `preempted/reordered`.
+                - A Sequence Group (SG) refers to a group of sequences
+                  that are generated from the same prompt.
+
+            - Step 2: Calls the distributed executor to execute the model.
+            - Step 3: Processes the model output. This mainly includes:
+
+                - Decodes the relevant outputs.
+                - Updates the scheduled sequence groups with model outputs
+                  based on its `sampling parameters` (`use_beam_search` or not).
+                - Frees the finished sequence groups.
+
+            - Finally, it creates and returns the newly generated results.
+
+        Example:
+            >>> # Please see the example/ folder for more detailed examples.
+            >>>
+            >>> # initialize engine and request arguments
+            >>> engine = LLMEngine.from_engine_args(engine_args)
+            >>> example_inputs = [(0, "What is LLM?",
+            >>>    SamplingParams(temperature=0.0))]
+            >>>
+            >>> # Start the engine with an event loop
+            >>> while True:
+            >>>     if example_inputs:
+            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
+            >>>         engine.add_request(str(req_id),prompt,sampling_params)
+            >>>
+            >>>     # continue the request processing
+            >>>     request_outputs = engine.step()
+            >>>     for request_output in request_outputs:
+            >>>         if request_output.finished:
+            >>>             # return or show the request output
+            >>>
+            >>>     if not (engine.has_unfinished_requests() or example_inputs):
+            >>>         break
+        """
+        if self.parallel_config.pipeline_parallel_size > 1:
+            raise NotImplementedError(
+                "Pipeline parallelism is only supported through AsyncLLMEngine "
+                "as performance will be severely degraded otherwise.")
+
+        # For llm_engine, there is no pipeline parallel support, so the engine
+        # used is always 0.
+        virtual_engine = 0
+
+        # These are cached outputs from previous iterations.
None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # Skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + # The scheduler is also skipped if a single request caused the last + # engine step to fail, and the previous schedule needs to be rerun. + if not self._has_remaining_steps( + seq_group_metadata_list + ) and not self._skip_scheduling_next_step: + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + # When n>1, elements in self.seq_id_to_seq_group should be deleted + # here, otherwise memory leaks. + for finished_request_id in finished_requests_ids: + if finished_request_id in self.seq_id_to_seq_group: + del self.seq_id_to_seq_group[finished_request_id] + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + else: + finished_requests_ids = list() + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + try: + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) + self._skip_scheduling_next_step = False + except InputProcessingError as e: + # The input for this request cannot be processed, so we must + # abort it. If there are remaining requests in the batch that + # have been scheduled, they will be retried on the next step. 
+ invalid_request_id = e.request_id + self._abort_and_cache_schedule( + request_id=invalid_request_id, + virtual_engine=virtual_engine, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + allow_async_output_proc=allow_async_output_proc) + # Raise so the caller is notified that this request failed + raise + + # We need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + # Nothing scheduled => If there is pending async postprocessor, + # then finish it here. + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + # No outputs in this case + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps. + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[0] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + # Add results to the output_queue + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( + "Async postprocessor expects only a single output set") + + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + # Check if need to run the usual non-async path + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + logger.debug("Stopping remote worker execution loop.") + self.model_executor.stop_remote_worker_execution_loop() + + return ctx.request_outputs + + def _abort_and_cache_schedule( + self, request_id: str, virtual_engine: int, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + """Aborts a single request, and caches the scheduler outputs minus that + request. 
This allows the next step to continue processing the remaining
+        requests without having to re-run the scheduler."""
+
+        # Abort the request and remove its sequence group from the current
+        # schedule
+        self.abort_request(request_id)
+        for i, metadata in enumerate(seq_group_metadata_list):
+            if metadata.request_id == request_id:
+                del seq_group_metadata_list[i]
+                break
+        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
+            if group.seq_group.request_id == request_id:
+                del scheduler_outputs.scheduled_seq_groups[i]
+                break
+
+        # If there are still other sequence groups left in the schedule, cache
+        # them and flag the engine to reuse the schedule.
+        if len(seq_group_metadata_list) > 0:
+            self._skip_scheduling_next_step = True
+            # Reuse multi-step caching logic
+            self._cache_scheduler_outputs_for_multi_step(
+                virtual_engine=virtual_engine,
+                scheduler_outputs=scheduler_outputs,
+                seq_group_metadata_list=seq_group_metadata_list,
+                allow_async_output_proc=allow_async_output_proc)
+
+    def _has_remaining_steps(
+        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
+    ) -> bool:
+        if (not self.scheduler_config.is_multi_step
+                or not seq_group_metadata_list):
+            return False
+
+        # TODO(will) this is a sanity check for now to make sure that all the
+        # seqs are on the same steps. Eventually we will want to do some sort of
+        # dynamic scheduling when doing multi-step decoding.
+        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
+        if any([
+                seq_group.state.remaining_steps != ref_remaining_steps
+                for seq_group in seq_group_metadata_list[1:]
+        ]):
+            raise AssertionError("All running sequence groups should "
+                                 "have the same remaining steps.")
+
+        return ref_remaining_steps > 0
+
+    def _cache_scheduler_outputs_for_multi_step(
+            self, virtual_engine: int,
+            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+            scheduler_outputs: SchedulerOutputs,
+            allow_async_output_proc: bool) -> None:
+        co = self.cached_scheduler_outputs[virtual_engine]
+
+        co.seq_group_metadata_list = seq_group_metadata_list
+        co.scheduler_outputs = scheduler_outputs
+        co.allow_async_output_proc = allow_async_output_proc
+        co.last_output = None
+
+    def _update_cached_scheduler_output(
+            self, virtual_engine: int,
+            output: List[Optional[SamplerOutput]]) -> None:
+        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
+                and output[0] is not None):
+            last_output = output[-1]
+            assert last_output is not None
+            assert last_output.sampled_token_ids_cpu is not None
+            assert last_output.sampled_token_ids is None
+            assert last_output.sampled_token_probs is None
+            self.cached_scheduler_outputs[
+                virtual_engine].last_output = last_output
+
+    def _get_last_sampled_token_ids(
+            self, virtual_engine: int) -> Optional[torch.Tensor]:
+        cached_last_output = self.cached_scheduler_outputs[
+            virtual_engine].last_output
+        if (self.scheduler_config.is_multi_step
+                and self.parallel_config.pipeline_parallel_size > 1
+                and cached_last_output is not None
+                and cached_last_output.sampled_token_ids_cpu is not None):
+            return cached_last_output.sampled_token_ids_cpu
+        return None
+
+    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
+        if not self.log_stats:
+            raise RuntimeError(
+                "Stat logging is disabled.
Set `disable_log_stats=False` " + "argument to enable.") + if logger_name in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} already exists.") + self.stat_loggers[logger_name] = logger + + def remove_logger(self, logger_name: str) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name not in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} does not exist.") + del self.stat_loggers[logger_name] + + def do_log_stats(self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: + """Forced log when no requests active.""" + if self.log_stats: + stats = self._get_stats(scheduler_outputs, model_output, + finished_before, skip) + for logger in self.stat_loggers.values(): + logger.log(stats) + + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. + + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + finished_before: Optional, indices of sequences that were finished + before. These sequences will be ignored. + skip: Optional, indices of sequences that were preempted. These + sequences will be ignored. + """ + now = time.time() + + # System State + # Scheduler State + num_running_sys = sum( + len(scheduler.running) for scheduler in self.scheduler) + num_swapped_sys = sum( + len(scheduler.swapped) for scheduler in self.scheduler) + num_waiting_sys = sum( + len(scheduler.waiting) for scheduler in self.scheduler) + + # KV Cache Usage in % + num_total_gpu = self.cache_config.num_gpu_blocks + gpu_cache_usage_sys = 0. + if num_total_gpu: # Guard against both None and 0 + num_free_gpu = sum( + scheduler.block_manager.get_num_free_gpu_blocks() + for scheduler in self.scheduler) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + + num_total_cpu = self.cache_config.num_cpu_blocks + cpu_cache_usage_sys = 0. + if num_total_cpu: # Guard against both None and 0 + num_free_cpu = sum( + scheduler.block_manager.get_num_free_cpu_blocks() + for scheduler in self.scheduler) + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. 
+ cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + num_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + time_queue_requests: List[float] = [] + time_inference_requests: List[float] = [] + time_prefill_requests: List[float] = [] + time_decode_requests: List[float] = [] + time_in_queue_requests: List[float] = [] + model_forward_time_requests: List[float] = [] + model_execute_time_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + n_requests: List[int] = [] + max_num_generation_tokens_requests: List[int] = [] + max_tokens_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # LoRA requests + running_lora_adapters = dict( + collectionsCounter([ + running_request.lora_request.lora_name + for scheduler in self.scheduler + for running_request in scheduler.running + if running_request.lora_request + ])) + waiting_lora_adapters = dict( + collectionsCounter([ + waiting_request.lora_request.lora_name + for scheduler in self.scheduler + for waiting_request in scheduler.waiting + if waiting_request.lora_request + ])) + max_lora_stat = "0" + if self.lora_config: + max_lora_stat = str(self.lora_config.max_loras) + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. + if scheduler_outputs is not None: + # For async postprocessor, already finished sequences need to be + # not counted (to avoid double counting) + actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + + num_generation_tokens_from_prefill_groups = 0 + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double logging when using async output proc + if finished_before and idx in finished_before: + actual_num_batched_tokens -= 1 + continue + + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + + group_was_prefill = idx < scheduler_outputs.num_prefill_groups + seq_group = scheduled_seq_group.seq_group + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_token_latency() + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. 
+ latency = seq_group.get_last_token_latency() + time_per_output_tokens_iter.append(latency) + if seq_group.state.current_step == 0: + # For async_output_proc, the do_log_stats() + # is called following init_multi_step(), which + # sets the current_step to zero. + actual_num_batched_tokens +=\ + seq_group.state.num_steps - 1 + else: + actual_num_batched_tokens +=\ + seq_group.state.current_step - 1 + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. + if seq_group.is_finished(): + # Latency timings + time_e2e_requests.append(now - + seq_group.metrics.arrival_time) + if (seq_group.metrics.first_scheduled_time is not None and + seq_group.metrics.first_token_time is not None): + time_queue_requests.append( + seq_group.metrics.first_scheduled_time - + seq_group.metrics.arrival_time) + time_prefill_requests.append( + seq_group.metrics.first_token_time - + seq_group.metrics.first_scheduled_time) + time_decode_requests.append( + now - seq_group.metrics.first_token_time) + time_inference_requests.append( + now - seq_group.metrics.first_scheduled_time) + if seq_group.metrics.time_in_queue is not None: + time_in_queue_requests.append( + seq_group.metrics.time_in_queue) + if seq_group.metrics.model_forward_time is not None: + model_forward_time_requests.append( + seq_group.metrics.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + model_execute_time_requests.append( + seq_group.metrics.model_execute_time * 1000) + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + max_num_generation_tokens_requests.append( + max(seq.get_output_len() + for seq in seq_group.get_seqs())) + if seq_group.sampling_params is not None: + n_requests.append(seq_group.sampling_params.n) + max_tokens_requests.append( + seq_group.sampling_params.max_tokens) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. + # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + actual_num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) + num_tokens_iter = (num_generation_tokens_iter + + num_prompt_tokens_iter) + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. 
+ if model_output and isinstance(model_output[0], SamplerOutput) and ( + model_output[0].spec_decode_worker_metrics is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None + + return Stats( + now=now, + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + num_tokens_iter=num_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, + spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, + + # Request stats + # Latency + time_e2e_requests=time_e2e_requests, + time_queue_requests=time_queue_requests, + time_inference_requests=time_inference_requests, + time_prefill_requests=time_prefill_requests, + time_decode_requests=time_decode_requests, + time_in_queue_requests=time_in_queue_requests, + model_forward_time_requests=model_forward_time_requests, + model_execute_time_requests=model_execute_time_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + max_num_generation_tokens_requests= + max_num_generation_tokens_requests, + n_requests=n_requests, + max_tokens_requests=max_tokens_requests, + finished_reason_requests=finished_reason_requests, + max_lora=str(max_lora_stat), + waiting_lora_adapters=list(waiting_lora_adapters.keys()), + running_lora_adapters=list(running_lora_adapters.keys())) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + + def start_profile(self) -> None: + self.model_executor.start_profile() + + def stop_profile(self) -> None: + self.model_executor.stop_profile() + + def sleep(self, level: int = 1) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.sleep(level=level) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.model_executor.is_sleeping + + def check_health(self) -> None: + if self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() + + def is_tracing_enabled(self) -> bool: + return self.tracer is not None + + def 
do_tracing(self, + scheduler_outputs: SchedulerOutputs, + finished_before: Optional[List[int]] = None) -> None: + if self.tracer is None: + return + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double tracing when using async output proc + if finished_before and idx in finished_before: + continue + + seq_group = scheduled_seq_group.seq_group + if seq_group.is_finished(): + self.create_trace_span(seq_group) + + def create_trace_span(self, seq_group: SequenceGroup) -> None: + if self.tracer is None or seq_group.sampling_params is None: + return + arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + + trace_context = extract_trace_context(seq_group.trace_headers) + + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds) as seq_span: + metrics = seq_group.metrics + ttft = metrics.first_token_time - metrics.arrival_time + e2e_time = metrics.finished_time - metrics.arrival_time + seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL, + self.model_config.model) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, + seq_group.request_id) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, + seq_group.sampling_params.temperature) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, + seq_group.sampling_params.top_p) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, + seq_group.sampling_params.max_tokens) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, + seq_group.sampling_params.n) + seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES, + seq_group.num_seqs()) + seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, + len(seq_group.prompt_token_ids)) + seq_span.set_attribute( + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + sum([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ])) + seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, + metrics.time_in_queue) + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft) + seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) + if metrics.scheduler_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER, + metrics.scheduler_time) + if metrics.model_forward_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD, + metrics.model_forward_time / 1000.0) + if metrics.model_execute_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, + metrics.model_execute_time) + + def _validate_model_inputs(self, inputs: ProcessorInputs, + lora_request: Optional[LoRARequest]): + encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) + + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + if self.model_config.is_multimodal_model: + prompt_inputs = decoder_inputs + else: + prompt_inputs = encoder_inputs or decoder_inputs + + prompt_ids = prompt_inputs["prompt_token_ids"] + + if prompt_ids is None or len(prompt_ids) == 0: + raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). 
" + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _build_logits_processors( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + + logits_processors = [] + + if sampling_params.guided_decoding is not None: + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + + logger.debug( + "Building guided decoding logits processor in " + "LLMEngine. Params: %s", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + if self.decoding_config.reasoning_backend is not None: + logger.debug("Building with reasoning backend %s", + self.decoding_config.reasoning_backend) + + processor = get_local_guided_decoding_logits_processor( + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) + if processor: + logits_processors.append(processor) + + # Unset so this doesn't get passed down to the model + sampling_params.guided_decoding = None + + if (sampling_params.logit_bias or sampling_params.allowed_token_ids): + tokenizer = self.get_tokenizer(lora_request=lora_request) + + processors = get_openai_logits_processors( + logit_bias=sampling_params.logit_bias, + allowed_token_ids=sampling_params.allowed_token_ids, + tokenizer=tokenizer) + logits_processors.extend(processors) + + # Unset so these don't get passed down to the model + sampling_params.logit_bias = None + sampling_params.allowed_token_ids = None + + if len(sampling_params.bad_words) > 0: + tokenizer = self.get_tokenizer(lora_request) + processors = get_bad_words_logits_processors( + bad_words=sampling_params.bad_words, tokenizer=tokenizer) + logits_processors.extend(processors) + + if logits_processors: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = logits_processors + else: + sampling_params.logits_processors.extend(logits_processors) + + return sampling_params + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, + kwargs) + + +if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + LLMEngine = V1LLMEngine # type: ignore diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py new file mode 100644 index 0000000..9d0f604 --- /dev/null +++ b/vllm/engine/metrics.py @@ -0,0 +1,719 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from typing import TYPE_CHECKING +from 
typing import Counter as CollectionsCounter
+from typing import Dict, List, Optional, Type, Union, cast
+
+import numpy as np
+import prometheus_client
+
+from vllm.config import SupportsMetricsInfo, VllmConfig
+from vllm.engine.metrics_types import StatLoggerBase, Stats
+from vllm.executor.ray_utils import ray
+from vllm.logger import init_logger
+
+if ray is not None:
+    from ray.util import metrics as ray_metrics
+else:
+    ray_metrics = None
+
+if TYPE_CHECKING:
+    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
+
+logger = init_logger(__name__)
+
+prometheus_client.disable_created_metrics()
+
+# The begin-* and end* here are used by the documentation generator
+# to extract the metrics definitions.
+
+
+# begin-metrics-definitions
+class Metrics:
+    """
+    vLLM uses a multiprocessing-based frontend for the OpenAI server.
+    This means that we need to run prometheus_client in multiprocessing mode.
+    See https://prometheus.github.io/client_python/multiprocess/ for more
+    details on limitations.
+    """
+
+    labelname_finish_reason = "finished_reason"
+    labelname_waiting_lora_adapters = "waiting_lora_adapters"
+    labelname_running_lora_adapters = "running_lora_adapters"
+    labelname_max_lora = "max_lora"
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram
+
+    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
+        # Unregister any existing vLLM collectors (for CI/CD)
+        self._unregister_vllm_metrics()
+
+        max_model_len = vllm_config.model_config.max_model_len
+
+        # Use this flag to hide metrics that were deprecated in
+        # a previous release and which will be removed in a future release
+        self.show_hidden_metrics = \
+            vllm_config.observability_config.show_hidden_metrics
+
+        # System stats
+        #   Scheduler State
+        self.gauge_scheduler_running = self._gauge_cls(
+            name="vllm:num_requests_running",
+            documentation="Number of requests currently running on GPU.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_scheduler_waiting = self._gauge_cls(
+            name="vllm:num_requests_waiting",
+            documentation="Number of requests waiting to be processed.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_lora_info = self._gauge_cls(
+            name="vllm:lora_requests_info",
+            documentation="Running stats on lora requests.",
+            labelnames=[
+                self.labelname_running_lora_adapters,
+                self.labelname_max_lora,
+                self.labelname_waiting_lora_adapters,
+            ],
+            multiprocess_mode="livemostrecent",
+        )
+
+        # Deprecated in 0.8 - KV cache offloading is not used in V1
+        # Hidden in 0.9, due to be removed in 0.10
+        if self.show_hidden_metrics:
+            self.gauge_scheduler_swapped = self._gauge_cls(
+                name="vllm:num_requests_swapped",
+                documentation=(
+                    "Number of requests swapped to CPU. "
+                    "DEPRECATED: KV cache offloading is not used in V1"),
+                labelnames=labelnames,
+                multiprocess_mode="sum")
+
+        # KV Cache Usage in %
+        self.gauge_gpu_cache_usage = self._gauge_cls(
+            name="vllm:gpu_cache_usage_perc",
+            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+
+        # Deprecated in 0.8 - KV cache offloading is not used in V1
+        # Hidden in 0.9, due to be removed in 0.10
+        if self.show_hidden_metrics:
+            self.gauge_cpu_cache_usage = self._gauge_cls(
+                name="vllm:cpu_cache_usage_perc",
+                documentation=(
+                    "CPU KV-cache usage. 1 means 100 percent usage. "
+                    "DEPRECATED: KV cache offloading is not used in V1"),
+                labelnames=labelnames,
+                multiprocess_mode="sum")
+            self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
+                name="vllm:cpu_prefix_cache_hit_rate",
+                documentation=(
+                    "CPU prefix cache block hit rate. "
+                    "DEPRECATED: KV cache offloading is not used in V1"),
+                labelnames=labelnames,
+                multiprocess_mode="sum")
+
+        # Deprecated in 0.8 - replaced by queries+hits counters in V1
+        # Hidden in 0.9, due to be removed in 0.10
+        if self.show_hidden_metrics:
+            self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
+                name="vllm:gpu_prefix_cache_hit_rate",
+                documentation=("GPU prefix cache block hit rate. "
+                               "DEPRECATED: use vllm:gpu_prefix_cache_queries "
+                               "and vllm:gpu_prefix_cache_hits in V1"),
+                labelnames=labelnames,
+                multiprocess_mode="sum")
+
+        # Iteration stats
+        self.counter_num_preemption = self._counter_cls(
+            name="vllm:num_preemptions_total",
+            documentation="Cumulative number of preemptions from the engine.",
+            labelnames=labelnames)
+        self.counter_prompt_tokens = self._counter_cls(
+            name="vllm:prompt_tokens_total",
+            documentation="Number of prefill tokens processed.",
+            labelnames=labelnames)
+        self.counter_generation_tokens = self._counter_cls(
+            name="vllm:generation_tokens_total",
+            documentation="Number of generation tokens processed.",
+            labelnames=labelnames)
+        buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
+        if not vllm_config.model_config.enforce_eager:
+            buckets = vllm_config.compilation_config.\
+                cudagraph_capture_sizes.copy()
+            buckets.sort()
+        self.histogram_iteration_tokens = self._histogram_cls(
+            name="vllm:iteration_tokens_total",
+            documentation="Histogram of number of tokens per engine_step.",
+            labelnames=labelnames,
+            buckets=buckets)
+        self.histogram_time_to_first_token = self._histogram_cls(
+            name="vllm:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labelnames,
+            buckets=[
+                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
+                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
+            ])
+        self.histogram_time_per_output_token = self._histogram_cls(
+            name="vllm:time_per_output_token_seconds",
+            documentation="Histogram of time per output token in seconds.",
+            labelnames=labelnames,
+            buckets=[
+                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
+                1.0, 2.5
+            ])
+
+        # Request stats
+        #   Latency
+        request_latency_buckets = [
+            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
+            40.0, 50.0, 60.0
+        ]
+        self.histogram_e2e_time_request = self._histogram_cls(
+            name="vllm:e2e_request_latency_seconds",
+            documentation="Histogram of end to end request latency in seconds.",
+            labelnames=labelnames,
+            buckets=request_latency_buckets)
+        self.histogram_queue_time_request = self._histogram_cls(
+            name="vllm:request_queue_time_seconds",
+            documentation=
+            "Histogram of time spent in WAITING phase for request.",
+            labelnames=labelnames,
+            buckets=request_latency_buckets)
+        self.histogram_inference_time_request = self._histogram_cls(
+            name="vllm:request_inference_time_seconds",
+            documentation=
+            "Histogram of time spent in RUNNING phase for request.",
+            labelnames=labelnames,
+            buckets=request_latency_buckets)
+        self.histogram_prefill_time_request = self._histogram_cls(
+            name="vllm:request_prefill_time_seconds",
+            documentation=
+            "Histogram of time spent in PREFILL phase for request.",
+            labelnames=labelnames,
+            buckets=request_latency_buckets)
+        self.histogram_decode_time_request = self._histogram_cls(
+
name="vllm:request_decode_time_seconds", + documentation= + "Histogram of time spent in DECODE phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + # Deprecated in 0.8 - duplicates vllm:request_queue_time_seconds: + # Hidden in 0.9, due to be removed in 0.10 + if self.show_hidden_metrics: + self.histogram_time_in_queue_request = self._histogram_cls( + name="vllm:time_in_queue_requests", + documentation= + ("Histogram of time the request spent in the queue in seconds. " + "DEPRECATED: use vllm:request_queue_time_seconds instead."), + labelnames=labelnames, + buckets=request_latency_buckets) + + # Deprecated in 0.8 - use prefill/decode/inference time metrics + # Hidden in 0.9, due to be removed in 0.10 + if self.show_hidden_metrics: + self.histogram_model_forward_time_request = self._histogram_cls( + name="vllm:model_forward_time_milliseconds", + documentation= + ("Histogram of time spent in the model forward pass in ms. " + "DEPRECATED: use prefill/decode/inference time metrics instead" + ), + labelnames=labelnames, + buckets=build_1_2_3_5_8_buckets(3000)) + self.histogram_model_execute_time_request = self._histogram_cls( + name="vllm:model_execute_time_milliseconds", + documentation= + ("Histogram of time spent in the model execute function in ms." + "DEPRECATED: use prefill/decode/inference time metrics instead" + ), + labelnames=labelnames, + buckets=build_1_2_3_5_8_buckets(3000)) + + # Metadata + self.histogram_num_prompt_tokens_request = self._histogram_cls( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_num_generation_tokens_request = \ + self._histogram_cls( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_max_num_generation_tokens_request = self._histogram_cls( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len)) + self.histogram_n_request = self._histogram_cls( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.histogram_max_tokens_request = self._histogram_cls( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.counter_request_success = self._counter_cls( + name="vllm:request_success_total", + documentation="Count of successfully processed requests.", + labelnames=labelnames + [Metrics.labelname_finish_reason]) + + # Speculative decoding stats + self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( + name="vllm:spec_decode_draft_acceptance_rate", + documentation="Speculative token acceptance rate.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_spec_decode_efficiency = self._gauge_cls( + name="vllm:spec_decode_efficiency", + documentation="Speculative decoding system efficiency.", + labelnames=labelnames, + multiprocess_mode="sum") + self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames)) + self.counter_spec_decode_num_draft_tokens = 
self._counter_cls( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_emitted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames)) + + +# end-metrics-definitions + + def _unregister_vllm_metrics(self) -> None: + for collector in list(prometheus_client.REGISTRY._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + prometheus_client.REGISTRY.unregister(collector) + + +class _RayGaugeWrapper: + """Wraps around ray.util.metrics.Gauge to provide same API as + prometheus_client.Gauge""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + multiprocess_mode: str = ""): + del multiprocess_mode + labelnames_tuple = tuple(labelnames) if labelnames else None + self._gauge = ray_metrics.Gauge(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._gauge.set_default_tags(labels) + return self + + def set(self, value: Union[int, float]): + return self._gauge.set(value) + + def set_to_current_time(self): + # ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html + return self._gauge.set(time.time()) + + +class _RayCounterWrapper: + """Wraps around ray.util.metrics.Counter to provide same API as + prometheus_client.Counter""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._counter = ray_metrics.Counter(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._counter.set_default_tags(labels) + return self + + def inc(self, value: Union[int, float] = 1.0): + if value == 0: + return + return self._counter.inc(value) + + +class _RayHistogramWrapper: + """Wraps around ray.util.metrics.Histogram to provide same API as + prometheus_client.Histogram""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + buckets: Optional[List[float]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + boundaries = buckets if buckets else [] + self._histogram = ray_metrics.Histogram(name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries) + + def labels(self, **labels): + self._histogram.set_default_tags(labels) + return self + + def observe(self, value: Union[int, float]): + return self._histogram.observe(value) + + +class RayMetrics(Metrics): + """ + RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. + Provides the same metrics as Metrics but uses Ray's util.metrics library. 
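+
+    The _Ray*Wrapper classes above only need to mimic the small slice of the
+    prometheus_client API that the stat loggers actually call (labels() plus
+    set()/inc()/observe()), so Metrics can construct either backend through
+    the same _gauge_cls/_counter_cls/_histogram_cls hooks.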
+ """ + _gauge_cls: Type[prometheus_client.Gauge] = cast( + Type[prometheus_client.Gauge], _RayGaugeWrapper) + _counter_cls: Type[prometheus_client.Counter] = cast( + Type[prometheus_client.Counter], _RayCounterWrapper) + _histogram_cls: Type[prometheus_client.Histogram] = cast( + Type[prometheus_client.Histogram], _RayHistogramWrapper) + + def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + if ray_metrics is None: + raise ImportError("RayMetrics requires Ray to be installed.") + super().__init__(labelnames, vllm_config) + + def _unregister_vllm_metrics(self) -> None: + # No-op on purpose + pass + + +def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values until the value exceeds the specified maximum. + + """ + exponent = 0 + buckets: List[int] = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + +def build_1_2_5_buckets(max_value: int) -> List[int]: + """ + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + return build_buckets([1, 2, 5], max_value) + + +def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: + """ + Example: + >>> build_1_2_3_5_8_buckets(100) + [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100] + """ + return build_buckets([1, 2, 3, 5, 8], max_value) + + +def local_interval_elapsed(now: float, last_log: float, + local_interval: float) -> bool: + elapsed_time = now - last_log + return elapsed_time > local_interval + + +def get_throughput(tracked_stats: List[int], now: float, + last_log: float) -> float: + return float(np.sum(tracked_stats) / (now - last_log)) + + +class LoggingStatLogger(StatLoggerBase): + """LoggingStatLogger is used in LLMEngine to log to Stdout.""" + + def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: + super().__init__(local_interval, vllm_config) + self.last_prompt_throughput: Optional[float] = None + self.last_generation_throughput: Optional[float] = None + + def log(self, stats: Stats) -> None: + """Called by LLMEngine. + Logs to Stdout every self.local_interval seconds.""" + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + # Compute summary metrics for tracked stats (and log them + # to promethus if applicable). 
+ prompt_throughput = get_throughput(self.num_prompt_tokens, + now=stats.now, + last_log=self.last_local_log) + generation_throughput = get_throughput( + self.num_generation_tokens, + now=stats.now, + last_log=self.last_local_log) + + log_fn = logger.info + if not any((prompt_throughput, generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput)): + # Avoid log noise on an idle production system + log_fn = logger.debug + + log_fn( + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Swapped: %d reqs, " + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%.", + prompt_throughput, + generation_throughput, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, + ) + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + log_fn( + "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", + stats.gpu_prefix_cache_hit_rate * 100, + stats.cpu_prefix_cache_hit_rate * 100, + ) + if self.spec_decode_metrics is not None: + log_fn( + self._format_spec_decode_metrics_str( + self.spec_decode_metrics)) + + self._reset(stats, prompt_throughput, generation_throughput) + + def _reset(self, stats, prompt_throughput, generation_throughput) -> None: + # Reset tracked stats for next interval. + self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + self.last_prompt_throughput = prompt_throughput + self.last_generation_throughput = generation_throughput + + def _format_spec_decode_metrics_str( + self, metrics: "SpecDecodeWorkerMetrics") -> str: + + return ("Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens: {metrics.emitted_tokens}.") + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + +class PrometheusStatLogger(StatLoggerBase): + """PrometheusStatLogger is used LLMEngine to log to Promethus.""" + _metrics_cls = Metrics + _gauge_cls = prometheus_client.Gauge + + def __init__(self, local_interval: float, labels: Dict[str, str], + vllm_config: VllmConfig) -> None: + super().__init__(local_interval, vllm_config) + # Prometheus metrics + self.labels = labels + self.metrics = self._metrics_cls(labelnames=list(labels.keys()), + vllm_config=vllm_config) + + def _log_gauge(self, gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + # Prevent ValueError from negative increment + if data < 0: + logger.warning("Skipping negative increment of %g to %s", data, + counter) + return + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter, data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for collection counter of labels. 
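+        # Illustrative example (values assumed): with
+        # data == CollectionsCounter({"stop": 7, "length": 1}) and
+        # label_key == "finished_reason", the loop below increments
+        # vllm:request_success_total once per finish reason, by 7 and by 1.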
+ for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram, data: Union[List[int], + List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) + + def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None: + gauge.labels(**data).set_to_current_time() + + def _log_prometheus(self, stats: Stats) -> None: + # System state data + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + if self.metrics.show_hidden_metrics: + self._log_gauge(self.metrics.gauge_scheduler_swapped, + stats.num_swapped_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + if self.metrics.show_hidden_metrics: + self._log_gauge(self.metrics.gauge_cpu_cache_usage, + stats.cpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate, + stats.cpu_prefix_cache_hit_rate) + self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, + stats.gpu_prefix_cache_hit_rate) + # Including max-lora in metric, in future this property of lora + # config maybe extended to be dynamic. + lora_info = { + self.metrics.labelname_running_lora_adapters: + ",".join(stats.running_lora_adapters), + self.metrics.labelname_waiting_lora_adapters: + ",".join(stats.waiting_lora_adapters), + self.metrics.labelname_max_lora: + stats.max_lora, + } + self._log_gauge_string(self.metrics.gauge_lora_info, lora_info) + # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_iteration_tokens, + [stats.num_tokens_iter]) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data + # Latency + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + self._log_histogram(self.metrics.histogram_queue_time_request, + stats.time_queue_requests) + self._log_histogram(self.metrics.histogram_inference_time_request, + stats.time_inference_requests) + self._log_histogram(self.metrics.histogram_prefill_time_request, + stats.time_prefill_requests) + self._log_histogram(self.metrics.histogram_decode_time_request, + stats.time_decode_requests) + if self.metrics.show_hidden_metrics: + self._log_histogram(self.metrics.histogram_time_in_queue_request, + stats.time_in_queue_requests) + self._log_histogram( + self.metrics.histogram_model_forward_time_request, + stats.model_forward_time_requests) + self._log_histogram( + self.metrics.histogram_model_execute_time_request, + stats.model_execute_time_requests) + # Metadata + finished_reason_counter = CollectionsCounter( + stats.finished_reason_requests) + self._log_counter_labels(self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram( + self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + 
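+        # Note: _log_histogram() records one observation per list element, so
+        # request-level histograms such as vllm:request_prompt_tokens receive
+        # one sample per request finished in this iteration.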
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + self._log_histogram( + self.metrics.histogram_max_num_generation_tokens_request, + stats.max_num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_max_tokens_request, + stats.max_tokens_requests) + + def log(self, stats: Stats): + """Logs to prometheus and tracked stats every iteration.""" + # Log to prometheus. + self._log_prometheus(stats) + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + if self.spec_decode_metrics is not None: + self._log_gauge( + self.metrics.gauge_spec_decode_draft_acceptance_rate, + self.spec_decode_metrics.draft_acceptance_rate) + self._log_gauge(self.metrics.gauge_spec_decode_efficiency, + self.spec_decode_metrics.system_efficiency) + self._log_counter( + self.metrics.counter_spec_decode_num_accepted_tokens, + self.spec_decode_metrics.accepted_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_draft_tokens, + self.spec_decode_metrics.draft_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_emitted_tokens, + self.spec_decode_metrics.emitted_tokens) + + # Reset tracked stats for next interval. + self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + # Info type metrics are syntactic sugar for a gauge permanently set to 1 + # Since prometheus multiprocessing mode does not support Info, emulate + # info here with a gauge. + if type == "cache_config": + metrics_info = obj.metrics_info() + info_gauge = self._gauge_cls( + name="vllm:cache_config_info", + documentation="Information of the LLMEngine CacheConfig", + labelnames=metrics_info.keys(), + multiprocess_mode="mostrecent") + info_gauge.labels(**metrics_info).set(1) + + +class RayPrometheusStatLogger(PrometheusStatLogger): + """RayPrometheusStatLogger uses Ray metrics instead.""" + _metrics_cls = RayMetrics + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + return None diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py new file mode 100644 index 0000000..9e6d5ef --- /dev/null +++ b/vllm/engine/metrics_types.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +These types are defined in this file to avoid importing vllm.engine.metrics +and therefore importing prometheus_client. + +This is required due to usage of Prometheus multiprocess mode to enable +metrics after splitting out the uvicorn process from the engine process. + +Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR +before prometheus_client is imported. Typically, this is done by setting +the env variable before launch, but since we are a library, we need to +do this in Python code and lazily import prometheus_client. 
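+
+For illustration only (not part of this module), the variable can be set from
+Python before the lazy import happens, e.g.:
+
+    import os
+    import tempfile
+
+    os.environ["PROMETHEUS_MULTIPROC_DIR"] = tempfile.mkdtemp()
+    import prometheus_client  # safe now that the directory is configured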
+""" + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + + +@dataclass +class Stats: + """Created by LLMEngine for use by StatLogger.""" + now: float + + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + num_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + num_preemption_iter: int + + # Request stats (should have _requests suffix) + # Latency + time_e2e_requests: List[float] + time_queue_requests: List[float] + time_inference_requests: List[float] + time_prefill_requests: List[float] + time_decode_requests: List[float] + time_in_queue_requests: List[float] + model_forward_time_requests: List[float] + model_execute_time_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + n_requests: List[int] + max_num_generation_tokens_requests: List[int] + max_tokens_requests: List[int] + finished_reason_requests: List[str] + waiting_lora_adapters: List[str] + running_lora_adapters: List[str] + max_lora: str + + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + +class StatLoggerBase(ABC): + """Base class for StatLogger.""" + + def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: + # Tracked stats over current local logging interval. 
+ self.num_prompt_tokens: List[int] = [] + self.num_generation_tokens: List[int] = [] + self.last_local_log = time.time() + self.local_interval = local_interval + self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None + + @abstractmethod + def log(self, stats: Stats) -> None: + raise NotImplementedError + + @abstractmethod + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + def maybe_update_spec_decode_metrics(self, stats: Stats): + """Save spec decode metrics (since they are unlikely + to be emitted at same time as log interval).""" + if stats.spec_decode_metrics is not None: + self.spec_decode_metrics = stats.spec_decode_metrics diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000..cafd815 --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 + +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Mapping, Optional, Union, overload + +from typing_extensions import deprecated + +from vllm import PoolingParams +from vllm.inputs import PromptType +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.utils import Device, deprecate_kwargs + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCProcessRequest: + prompt: PromptType + params: Union[SamplingParams, PoolingParams] + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + priority: int = 0 + + @overload + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @overload + @deprecated("'inputs' will be renamed to 'prompt") + def __init__( + self, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... 
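+    # Illustrative construction (values assumed, not part of this patch): the
+    # client pickles one of these and pushes it over the zmq input socket,
+    # e.g.
+    #
+    #     req = RPCProcessRequest(prompt="Hello",
+    #                             params=SamplingParams(max_tokens=16),
+    #                             request_id="req-0")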
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def __init__( + self, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + if inputs is not None: + prompt = inputs + assert (prompt is not None and params is not None + and request_id is not None) + + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + self.priority = priority + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +class RPCUProfileRequest(Enum): + START_PROFILE = 1 + STOP_PROFILE = 2 + + +@dataclass +class RPCResetPrefixCacheRequest: + device: Device + + +class RPCSleepRequest(Enum): + SLEEP_LEVEL_1 = 1 + SLEEP_LEVEL_2 = 2 + + +@dataclass +class RPCWakeUpRequest: + tags: Optional[list[str]] = None + + +@dataclass +class RPCIsSleepingRequest: + # Set the default value of request_id to a new UUID + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class RPCIsSleepingResponse: + request_id: str + is_sleeping: bool + + +@dataclass +class RPCLoadAdapterRequest: + lora_request: LoRARequest + # Set the default value of request_id to a new UUID + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class RPCAdapterLoadedResponse: + request_id: str + + +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, + RPCUProfileRequest, RPCLoadAdapterRequest, + RPCResetPrefixCacheRequest, RPCSleepRequest, + RPCWakeUpRequest, RPCIsSleepingRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, + RPCIsSleepingResponse, RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. 
Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py new file mode 100644 index 0000000..f058b13 --- /dev/null +++ b/vllm/engine/multiprocessing/client.py @@ -0,0 +1,742 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union, cast, overload) + +import cloudpickle +import psutil +import zmq +import zmq.asyncio +from typing_extensions import deprecated +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm import PoolingParams +from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.async_llm_engine import ( + build_guided_decoding_logits_processor_async) +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCAdapterLoadedResponse, RPCError, + RPCIsSleepingRequest, + RPCIsSleepingResponse, + RPCLoadAdapterRequest, + RPCProcessRequest, + RPCResetPrefixCacheRequest, + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) +from vllm.engine.protocol import EngineClient +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import Device, deprecate_kwargs + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient(EngineClient): + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. 
+ - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: VllmConfig, + engine_pid: int): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + lora_config=engine_config.lora_config) + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + + # Loop to handle output of the LLMEngine periodically. + # Started after the MQLLMEngine is ready so that we can + # build the Client in an executor to enable clean shutdown. + self.output_loop: Optional[asyncio.Task] = None + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + self._engine_process = psutil.Process(engine_pid) + + @staticmethod + def is_unsupported_config(vllm_config: VllmConfig): + # Pipeline parallel not yet supported + return vllm_config.parallel_config.pipeline_parallel_size > 1 + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually checks to ensure the engine process + is still alive. 
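+
+        If the engine process has died (or a heartbeat check fails), the
+        client is moved into an errored state and the output handler fails
+        any in-flight requests with an engine-dead error.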
+ """ + try: + while True: + # Check if the engine process is running: + if not self._engine_process.is_running() or ( + self._engine_process.status() == psutil.STATUS_ZOMBIE): + # NB: is_running() returns True for zombies + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) " + "died.")) + break + + if await self.heartbeat_socket.poll(timeout=timeout): + # Heartbeat received- check the message + await self._check_success( + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) + + logger.debug("Heartbeat successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except psutil.NoSuchProcess: + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) died.")) + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. + if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + # If engine is errored, no matter the type of exception + # it will no longer be able to receive new requests, + # therefore we have to inform that the current + # processed requests failed as well. Send back a dead + # engine error give this feedback and also give a + # 'hint' to the server to shutdown next. + exception = self.dead_error + + if request_id is None: + # If request_id is None, then the engine raised an + # exception for a batch, and we may not know the + # request that caused it, neither if it was actually + # caused by any of them (e.g. CUDA OOM). Therefore we + # broadcast the same exception for all requests. + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + # Put each output into the appropriate queue. 
+ elif isinstance( + request_outputs, + (RPCAdapterLoadedResponse, RPCIsSleepingResponse)): + self._add_output(request_outputs) + else: + for request_output in request_outputs: + self._add_output(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + def _add_output(self, request_output: Union[RequestOutput, + RPCAdapterLoadedResponse, + RPCIsSleepingResponse]): + queue = self.output_queues.get(request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Start output_loop + if self.output_loop is None: + # only generate once to avoid multiple concurrent output_loops + # this will lead to race conditions and wrong orders of tokens + # returned by the engine + # setup will be called multiple times during the startup of + # the engine + self.output_loop = asyncio.create_task( + self.run_output_handler_loop()) + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + if self.health_loop is None: + self.health_loop = asyncio.create_task( + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + if self.output_loop is not None: + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. 
+ frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.input_preprocessor + + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + """ + Ignore do_log_stats (handled on MQLLMEngine polling) + """ + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. + """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return ENGINE_DEAD_ERROR(self._errored_with) + + @overload + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + ... 
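+    # Illustrative usage sketch (names and values assumed, not part of this
+    # patch): callers iterate the async generator until an output with
+    # finished=True arrives, e.g.
+    #
+    #     async for out in client.generate(prompt="Hello",
+    #                                      sampling_params=SamplingParams(),
+    #                                      request_id="req-0"):
+    #         print(out.outputs[0].text)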
+ + @overload + @deprecated("'inputs' will be renamed to 'prompt") + def generate( + self, + *, + inputs: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def generate( + self, + prompt: Optional[PromptType] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None # DEPRECATED + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the + scheduling policy is not "priority". + """ + if inputs is not None: + prompt = inputs + assert (prompt is not None and sampling_params is not None + and request_id is not None) + + return self._process_request(prompt, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request, priority) + + @overload + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + ... + + @overload + @deprecated("'inputs' will be renamed to 'prompt") + def encode( + self, + *, + inputs: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def encode( + self, + prompt: Optional[PromptType] = None, + pooling_params: Optional[PoolingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None # DEPRECATED + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. 
+ pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + + Yields: + The output `PoolingRequestOutput` objects from the LLMEngine + for the request. + """ + if inputs is not None: + prompt = inputs + assert (prompt is not None and pooling_params is not None + and request_id is not None) + + return cast( + AsyncGenerator[PoolingRequestOutput, None], + self._process_request(prompt, + pooling_params, + request_id, + lora_request, + trace_headers, + priority=priority)) + + async def _process_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ + PoolingRequestOutput, None]]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # Ensure the request id is unique among running requests + if request_id in self.output_queues: + raise ValueError(f"Request {request_id} already exists") + + # Constructing guided decoding logits processors is expensive, so we do + # it here to avoid contending with cpu resources and the GIL on the + # backend process. + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + params = await \ + build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer(lora_request), + default_guided_backend=(self.decoding_config.guided_decoding_backend + if self.decoding_config + else DecodingConfig.guided_decoding_backend), + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) + + # 1) Create output queue for this requests. + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if isinstance(params, SamplingParams) and params.logits_processors: + # Defensive shallow copy + params = copy.copy(params) + logits_processors = params.logits_processors + params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCProcessRequest( + prompt=prompt, + params=params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + )) + + # 3) Send the RPCGenerateRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. 
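+                # (i.e. the caller stopped iterating before a finished output
+                # arrived), so ask the engine to abort the request below.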
+ if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + """Reset the prefix cache""" + + await self._send_one_way_rpc_request( + request=RPCResetPrefixCacheRequest(device), + socket=self.input_socket) + + async def sleep(self, level: int = 1) -> None: + """Sleep the engine for a given level""" + return await self._send_one_way_rpc_request( + request=RPCSleepRequest(level), socket=self.input_socket) + + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + """Wake up the engine""" + return await self._send_one_way_rpc_request( + request=RPCWakeUpRequest(tags), socket=self.input_socket) + + async def is_sleeping(self) -> bool: + """Check whether the engine is sleeping""" + request = RPCIsSleepingRequest() + + queue: asyncio.Queue[Union[BaseException, + RPCIsSleepingResponse]] = asyncio.Queue() + self.output_queues[request.request_id] = queue + + request_bytes = pickle.dumps(request) + await self.input_socket.send_multipart((request_bytes, ), copy=False) + + request_output = await queue.get() + self.output_queues.pop(request.request_id) + + if isinstance(request_output, BaseException): + raise request_output + return request_output.is_sleeping + + async def add_lora(self, lora_request: LoRARequest) -> None: + """Load a new LoRA adapter into the engine for future requests.""" + # Uses the same I/O as generate requests + request = RPCLoadAdapterRequest(lora_request) + + # Create output queue for this requests. 
+ queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() + self.output_queues[request.request_id] = queue + + # Send the request + request_bytes = pickle.dumps(request) + await self.input_socket.send_multipart((request_bytes, ), copy=False) + + # Wait for the response + request_output = await queue.get() + self.output_queues.pop(request.request_id) + + # Raise on error, otherwise happily return None + if isinstance(request_output, BaseException): + raise request_output diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py new file mode 100644 index 0000000..6ed5ae0 --- /dev/null +++ b/vllm/engine/multiprocessing/engine.py @@ -0,0 +1,450 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import AsyncEngineArgs, SamplingParams +from vllm.config import VllmConfig +from vllm.engine.llm_engine import LLMEngine +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCAdapterLoadedResponse, RPCError, + RPCIsSleepingRequest, + RPCIsSleepingResponse, + RPCLoadAdapterRequest, + RPCProcessRequest, + RPCResetPrefixCacheRequest, + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) +# yapf: enable +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) +from vllm.usage.usage_lib import UsageContext +from vllm.worker.model_runner_base import InputProcessingError + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`LLMEngine` generate or encode process is kicked off when a new + RPCProcessRequest is received by the input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. 
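+        # For illustration (path assumed): a base ipc_path such as
+        # "ipc:///tmp/vllm-mq" expands into the four endpoints used below:
+        #   ipc:///tmp/vllm-mq_input_socket   (PULL, requests from the client)
+        #   ipc:///tmp/vllm-mq_output_socket  (PUSH, streamed RequestOutputs)
+        #   ipc:///tmp/vllm-mq_health_socket  (PUSH, heartbeats to the client)
+        #   ipc:///tmp/vllm-mq_data_socket    (ROUTER, startup handshake)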
+ kwargs['use_cached_outputs'] = True + + self.engine = LLMEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_vllm_config(cls, vllm_config: VllmConfig, + usage_context: UsageContext, + disable_log_requests: bool, disable_log_stats: bool, + ipc_path: str) -> "MQLLMEngine": + # Setup plugins for each process + from vllm.plugins import load_general_plugins + load_general_plugins() + + use_async_sockets = vllm_config.model_config.use_async_output_proc + + return cls( + vllm_config=vllm_config, + executor_class=LLMEngine._get_executor_cls(vllm_config), + ipc_path=ipc_path, + usage_context=usage_context, + use_async_sockets=use_async_sockets, + log_requests=(not disable_log_requests), + log_stats=(not disable_log_stats), + ) + + @staticmethod + def from_engine_args(engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + + vllm_config = engine_args.create_engine_config(usage_context) + return MQLLMEngine.from_vllm_config( + ipc_path=ipc_path, + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_requests=engine_args.disable_log_requests, + disable_log_stats=engine_args.disable_log_stats, + ) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. + self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. 
+ if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + # When there's no work, check on engine health and send + # health status back to client + self._health_check() + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + try: + return self.engine.step() + except SystemExit: + raise + except InputProcessingError as e: + # Special case where we handle an error preparing the inputs for + # a single request in the batch + rpc_err = RPCError(request_id=e.request_id, + is_engine_errored=False, + exception=e.__cause__) + self._send_outputs(rpc_err) + return [] + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCProcessRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) + lprocs = cloudpickle.loads(frames[1].buffer) + request.params.logits_processors = lprocs + self._handle_process_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() + elif isinstance(request, RPCLoadAdapterRequest): + self._handle_load_adapter_request(request) + elif isinstance(request, RPCResetPrefixCacheRequest): + self.reset_prefix_cache() + elif isinstance(request, RPCSleepRequest): + self.sleep(request.value) + elif isinstance(request, RPCWakeUpRequest): + self.wake_up(request.tags) + elif isinstance(request, RPCIsSleepingRequest): + self._handle_is_sleeping_request(request) + else: + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + prompt=request.prompt, + params=request.params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + 
priority=request.priority) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. + logger.debug("Failed to add request %s to engine. %s", + request.request_id, e) + is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): + try: + self.engine.add_lora(request.lora_request) + except BaseException as e: + # Send back an error if the adater fails to load + rpc_err = RPCError(request_id=request.request_id, + is_engine_errored=False, + exception=e) + self._send_outputs(rpc_err) + return + # Otherwise, send back the successful load message + self._send_outputs( + RPCAdapterLoadedResponse(request_id=request.request_id)) + + def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): + is_sleeping = self.is_sleeping() + self._send_outputs( + RPCIsSleepingResponse(request_id=request.request_id, + is_sleeping=is_sleeping)) + + def _health_check(self): + # Send unhealthy if engine has already errored + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send outputs back to the engine client. These can be: + - Exceptions + - A list of generation outputs + - A response from loading a lora adapter + """ + if outputs: + try: + from ray.exceptions import RayTaskError + + # RayTaskError might not pickelable here. We need to unpack the + # underlying exception as the real exception in the output. 
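
`_send_outputs` has to pickle exceptions before pushing them over the socket, and the Ray-specific branch here exists because some exception types do not survive pickling. As a generic illustration of the same concern (this is not vLLM code, just a hedged sketch), one can test whether an exception round-trips through pickle and fall back to a plain stand-in that preserves the message:

```python
import pickle
import socket


def make_picklable(exc: BaseException) -> BaseException:
    """Return exc if it round-trips through pickle, else a plain stand-in."""
    try:
        pickle.loads(pickle.dumps(exc))
        return exc
    except Exception:
        # Fall back to a generic exception carrying the original message.
        return RuntimeError(f"{type(exc).__name__}: {exc}")


class Unpicklable(Exception):
    def __init__(self, sock):
        super().__init__("boom")
        self.sock = sock  # holding a live socket makes the instance unpicklable


if __name__ == "__main__":
    bad = Unpicklable(socket.socket())
    safe = make_picklable(bad)
    print(type(safe).__name__, safe)  # -> RuntimeError Unpicklable: boom
```
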
+ if (isinstance(outputs, RPCError) + and isinstance(outputs.exception, RayTaskError)): + outputs.exception = outputs.exception.cause + except ImportError: + pass + + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + def start_profile(self) -> None: + self.engine.start_profile() + + def stop_profile(self) -> None: + self.engine.stop_profile() + + def reset_prefix_cache(self) -> bool: + return self.engine.reset_prefix_cache() + + def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.engine.is_sleeping() + + +def signal_handler(*_) -> None: + raise KeyboardInterrupt("MQLLMEngine terminated") + + +def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext, + ipc_path: str, disable_log_stats: bool, + disable_log_requests: bool, engine_alive): + try: + # Ensure we can serialize transformer config before spawning + maybe_register_config_serialize_by_value() + + engine = MQLLMEngine.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_stats=disable_log_stats, + disable_log_requests=disable_log_requests, + ipc_path=ipc_path) + + signal.signal(signal.SIGTERM, signal_handler) + + engine.start() + + except BaseException as e: + logger.exception(e) + engine_alive.value = False + raise e diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py new file mode 100644 index 0000000..4c8e295 --- /dev/null +++ b/vllm/engine/output_processor/interfaces.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Callable, List + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + + +class SequenceGroupOutputProcessor(ABC): + """Interface for logic that processes new token ids in sequence groups, + managing detokenization, stop checking, and freeing/forking sequences with + the scheduler. + + This is highly coupled with the LLMEngine and should be seen as an extension + of it. 
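
`run_mp_engine` above installs a handler that converts SIGTERM into a `KeyboardInterrupt`, so orderly termination from a parent process and Ctrl-C share the same `try/except/finally` shutdown path. A minimal sketch of that signal-handling idiom is below; the worker body and process layout are stand-ins for illustration only.

```python
import os
import signal
import time


def signal_handler(*_) -> None:
    raise KeyboardInterrupt("terminated")


def run_worker() -> None:
    signal.signal(signal.SIGTERM, signal_handler)
    try:
        while True:            # stand-in for the engine busy loop
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("shutting down cleanly")
    finally:
        print("cleanup done")


if __name__ == "__main__":
    pid = os.fork()
    if pid == 0:               # child: the "engine" process
        run_worker()
    else:                      # parent: terminate the child after a moment
        time.sleep(0.3)
        os.kill(pid, signal.SIGTERM)
        os.waitpid(pid, 0)
```
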
The logic is separated to simplify the LLMEngine class and allow + separate implementations for single-step decoding (which supports beam + search sequence forking) and multi-step decoding (which does not support + beam search, but does support speculative decoding). + """ + + @staticmethod + def create_output_processor( + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: "StopChecker", + ): + """Create an output processor. + + This returns a single-step output processor if num_lookahead_slots is + zero, else returns a multi-step output processor. + """ + if scheduler_config.num_lookahead_slots == 0: + # Importing here to avoid cycle. + from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, + stop_checker) + else: + # Importing here to avoid cycle. + from vllm.engine.output_processor.multi_step import ( + MultiStepOutputProcessor) + return MultiStepOutputProcessor( + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + stop_checker, + ) + + @abstractmethod + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Process new token ids for the sequence group. Handles logic such as + detokenization, stop checking, and freeing/forking sequences in the + scheduler. + """ + pass + + @abstractmethod + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Update prompt logprobs received from outputs to seq_group.""" + pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py new file mode 100644 index 0000000..4c5d78a --- /dev/null +++ b/vllm/engine/output_processor/multi_step.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +from typing import Callable, List, cast + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.single_step import ( + single_step_process_prompt_logprob) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +class MultiStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to + detokenization and stopping conditions. It specializes to "multi-step + decoding", where vLLM's worker may generate multiple tokens per invocation. + This is currently mutually exclusive with advanced sampling techniques like + beam search, which motivates the separation of this logic from the single + step output processor. 
+ + This class is responsible for things such as correctly appending all new + token ids to their sequence, detokenizing new token ids, truncating new + output tokens after an eos token, and correctly handling the case where the + number of new output tokens per sequence differs in a single batch. + """ + + def __init__( + self, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: StopChecker, + ): + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with each step of a multi-step- + scheduled computation. + + Args: + seq_group: the outputs are associated with this :class:`SequenceGroup` + outputs: the :class:`SequenceGroupOutput`s for all scheduler steps + """ + for output in outputs: + # Concatenate single-step prompt logprob processing results. + assert isinstance(output, CompletionSequenceGroupOutput) + single_step_process_prompt_logprob(self, seq_group, output) + + @staticmethod + @functools.lru_cache + def _log_prompt_logprob_unsupported_warning_once(): + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + logger.warning( + "Prompt logprob is not supported by multi step workers. " + "(e.g., speculative decode uses multi step workers).") + + def process_outputs(self, + sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool = False) -> None: + """Append new tokens in the outputs to sequences in the sequence group. + + This only supports sequence groups of size 1. It supports greater than + one new token per sequence. + + This applies logic like stop condition checking and detokenization. + It also handles cases where there are tokens emitted after + the EOS token. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + # Sequences can be in RUNNING or FINISHED_ABORTED state + # once scheduled, as a sequence is moved to FINSIHED_ABORTED + # if a client disconnects from the api server. + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + if seqs is None: + seqs = sequence_group.get_seqs( + status=SequenceStatus.FINISHED_ABORTED) + + for output in outputs: + if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID: + sequence_group.metrics.spec_token_acceptance_counts[ + output.step_index] += 1 + + assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" + assert len(seqs) == 1, ( + "Beam search not supported in multi-step decoding.") + seq = seqs[0] + seq_id = seq.seq_id + # This method is defined in the more generic + # SequenceGroupOutputProcessor, but here we assume that the outputs are + # of a more specific type. + assert all([ + isinstance(output, CompletionSequenceGroupOutput) + for output in outputs + ]) + compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs) + assert all([ + seq_id == output.samples[0].parent_seq_id + for output in compl_outputs + ]) + + if is_async: + # Async case: We process tokens one by one. 
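
`_log_prompt_logprob_unsupported_warning_once` relies on `functools.lru_cache` over a zero-argument function, so the body (and thus the warning) runs only on the first call. A tiny standalone sketch of that idiom:

```python
import functools
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")


@functools.lru_cache
def _warn_unsupported_once() -> None:
    # The cached None return value means this body executes only once.
    logger.warning("Prompt logprobs are not supported in this mode.")


for _ in range(3):
    _warn_unsupported_once()   # the warning is printed a single time
```
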
Here, we know the token + # was already appended, so we only need to do the rest of the + # postprocessor: Detokenization + stopping logic + self._process_decode_and_stop(seq, sequence_group.sampling_params) + else: + # Standard multi-step case + + # Since there's only one sequence per sequence group, + # we can take the first sample. + samples = [output.samples[0] for output in compl_outputs] + + # entries in sample tokens may be invalid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID + ] + + # When both spec-decode and pre-fill chunking are enabled, we + # don't have guaranteed samples here (e.g. all -1s). + if valid_samples: + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_decode_and_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + new_char_count = 0 + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + # TODO(sang): Support lora. + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + + def _process_seq_outputs(self, seq: Sequence, + valid_samples: List[SequenceOutput], + sampling_params: SamplingParams) -> None: + output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] + + # Truncate to max_tokens if necessary. + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + + len(output_token_ids)) + if remaining_tokens < 0: + output_token_ids = output_token_ids[:remaining_tokens] + + # Truncate any tokens after EOS. This is required as spec decode + # generates a fixed number of tokens without evaluating stopping + # conditions within the block. This can cause an eos token to be + # unintentionally ignored. + if not sampling_params.ignore_eos: + eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + break + + is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 + # Incrementally append tokens to the sequence, as if we had only one new + # token. + for output_token_id, output_logprob in zip(output_token_ids, + output_logprobs): + seq.append_token_id( + token_id=output_token_id, + logprobs=output_logprob, + ) + + if is_prefill_sampled_token: + is_prefill_sampled_token = False + else: + # Update num_computed_tokens iff the sampled token is not from + # a prefill step. 
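
`_process_seq_outputs` first trims the freshly sampled tokens so the sequence never exceeds `max_tokens`, then cuts everything after the first EOS, because speculative decoding can emit a fixed block of tokens without checking stop conditions in between. A standalone sketch of those two truncations on plain lists (the token ids and budget are made up):

```python
from typing import List


def truncate_new_tokens(new_tokens: List[int], already_generated: int,
                        max_tokens: int, eos_token_id: int,
                        ignore_eos: bool = False) -> List[int]:
    # 1) Respect the max_tokens budget.
    remaining = max_tokens - (already_generated + len(new_tokens))
    if remaining < 0:
        new_tokens = new_tokens[:remaining]

    # 2) Drop anything after the first EOS (keeping the EOS itself).
    if not ignore_eos:
        for i, token_id in enumerate(new_tokens):
            if token_id == eos_token_id:
                new_tokens = new_tokens[:i + 1]
                break
    return new_tokens


# 5 tokens proposed, 7 already generated, budget of 10, EOS id 2:
print(truncate_new_tokens([11, 12, 2, 13, 14], already_generated=7,
                          max_tokens=10, eos_token_id=2))  # -> [11, 12, 2]
```
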
+ seq.data.update_num_computed_tokens(1) + + self._process_decode_and_stop(seq, sampling_params) + + if seq.is_finished(): + break diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py new file mode 100644 index 0000000..4d96791 --- /dev/null +++ b/vllm/engine/output_processor/single_step.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup, + SequenceGroupOutput) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +def single_step_process_prompt_logprob( + sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, + output: CompletionSequenceGroupOutput) -> None: + """Process prompt logprobs associated with the :class:`SequenceGroupOutput` + for a given step. + + Do nothing if the output has no prompt logprobs. + + Account for the fact that transformers do not compute first-token logprobs. + + Args: + sg_output_proc: :class:`SequenceGroupOutputProcessor` instance + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ + prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. + if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + + assert hasattr(sg_output_proc, 'detokenizer') + if (seq_group.sampling_params.detokenize + and sg_output_proc.detokenizer): + sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( + seq_group, + prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + + seq_group.prompt_logprobs.extend(prompt_logprobs) + + +class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles "output processing" logic, + which happens after the model returns generated token ids and before + scheduling of the next batch. Output processing logic includes + detokenization, and determining if a sequence is finished (e.g. via max len + or eos token). + + The SingleStepOutputProcessor is specialized to the case where the model + emits at most a single token per invocation, which precludes configurations + such as speculative decoding or multi-step decoding. This enables beam + search sampling, which requires forking/finishing/freeing sequences in a way + that is currently difficult to schedule multiple steps ahead of time. 
+ """ + + def __init__(self, scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, stop_checker: StopChecker): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.stop_checker = stop_checker + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Append all new tokens to sequences in the sequence group. Fork any + surviving beam candidates; free any unsurviving ones. + + Invokes detokenizer to detokenize new tokens, and also marks sequences + as finished if they meet stop conditions. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + assert (len(outputs) == 1 + ), f"{type(self)} does not support multiple outputs per step" + return self._process_sequence_group_outputs(sequence_group, outputs[0], + is_async) + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with one step of a single-step- + scheduled computation. + + Args: + seq_group: the output is associated with this :class:`SequenceGroup` + outputs: the :class:`SequenceGroupOutput` for a single scheduler step + """ + assert len(outputs) == 1, "Single step should only have 1 output." + output = outputs[0] + assert isinstance(output, CompletionSequenceGroupOutput) + single_step_process_prompt_logprob(self, seq_group, output) + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput, + is_async: bool) -> None: + sampling_params = seq_group.sampling_params + + sample = outputs.samples[0] + seq = seq_group.first_seq + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py new file mode 100644 index 0000000..6cad9ec --- /dev/null +++ b/vllm/engine/output_processor/stop_checker.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional, Tuple + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence, SequenceStatus +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class StopChecker: + """LLMEngine helper class which separates out the logic involving stop + checking. This checks things such as: whether the eos token was emitted, + whether the max_tokens has been consumed, whether a stop string has been + emitted, or if we have exceeded the max model len. + """ + + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): + # Do not use it directly, but use `self._get_max_model_len`. 
+ self._max_model_len = max_model_len + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): + if lora_req and lora_req.long_lora_max_len: + return lora_req.long_lora_max_len + else: + return self._max_model_len + + def maybe_stop_sequence( + self, + seq: Sequence, + new_char_count: int, + sampling_params: SamplingParams, + lora_req: Optional[LoRARequest] = None, + ) -> None: + """Stop the finished sequences. + + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + # Remove the last EOS token unless explicitly specified + # This prevents unintended exposure of the EOS token + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. + last_token_id = seq.get_last_token_id() + if last_token_id in (sampling_params.stop_token_ids or ()): + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if any stop strings are matched. + stop = self.check_stop_strings( + seq.output_text, new_char_count, sampling_params.stop, + sampling_params.include_stop_str_in_output) + if stop is not None: + stop_str, truncate_to = stop + if truncate_to != -1: + seq.output_text = seq.output_text[:truncate_to] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() > self._get_max_model_len(lora_req): + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def check_stop_strings( + output_text: str, + new_char_count: int, + stop: List[str], + include_in_output: bool, + ) -> Optional[Tuple[str, int]]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns tuple (stop_string, offset) if matched or else None. + + Where stop_string is the matched stop string and offset is the + length to which output_text should be truncated, or -1 for no + truncation. + """ + if not new_char_count or not stop: + return None + + for stop_str in stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = output_text.find(stop_str, + 1 - new_char_count - stop_string_len) + if stop_index == -1: + continue + + if include_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(output_text): + # No truncation required. + return stop_str, -1 + + # Truncate the output text to either the beginning + # or end of the stop string. 
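
`check_stop_strings` only searches the suffix of the output text that the newest characters could possibly have completed, by starting `str.find` at the negative offset `1 - new_char_count - len(stop_str)`. The following standalone sketch mirrors that windowed search and shows a sample call; the strings are illustrative.

```python
from typing import List, Optional, Tuple


def check_stop_strings(output_text: str, new_char_count: int,
                       stop: List[str],
                       include_in_output: bool) -> Optional[Tuple[str, int]]:
    """Return (stop_string, truncate_to) if a stop string just completed."""
    if not new_char_count or not stop:
        return None
    for stop_str in stop:
        # Only search the tail that the newest chars could have completed.
        start = 1 - new_char_count - len(stop_str)
        stop_index = output_text.find(stop_str, start)
        if stop_index == -1:
            continue
        if include_in_output:
            stop_index += len(stop_str)
            if stop_index >= len(output_text):
                return stop_str, -1   # nothing needs to be cut
        return stop_str, stop_index
    return None


# The last 2 chars ("##") completed the stop string "###":
print(check_stop_strings("some answer ###", new_char_count=2,
                         stop=["###"], include_in_output=False))
# -> ('###', 12), i.e. truncate the text to "some answer "
```
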
+ return stop_str, stop_index + return None diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py new file mode 100644 index 0000000..0d2b58c --- /dev/null +++ b/vllm/engine/output_processor/util.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List +from typing import Sequence as GenericSequence +from typing import cast + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput + + +def create_output_by_sequence_group( + outputs: GenericSequence[SamplerOutput], + num_seq_groups: int) -> List[List[SequenceGroupOutput]]: + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ + output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [ + [] for _ in range(num_seq_groups) + ] + for step in outputs: + sequence_group_output: CompletionSequenceGroupOutput + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + # Cast to the more generic type that CompletionSequenceGroupOutput + # inherits from. + return cast(List[List[SequenceGroupOutput]], output_by_sequence_group) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py new file mode 100644 index 0000000..e2974b0 --- /dev/null +++ b/vllm/engine/protocol.py @@ -0,0 +1,297 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from abc import ABC, abstractmethod +from typing import AsyncGenerator, List, Mapping, Optional + +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Device, collect_from_async_generator, random_uuid + +logger = init_logger(__name__) + + +class EngineClient(ABC): + """Protocol class for Clients to Engine""" + + @property + @abstractmethod + def is_running(self) -> bool: + ... + + @property + @abstractmethod + def is_stopped(self) -> bool: + ... + + @property + @abstractmethod + def errored(self) -> bool: + ... + + @property + @abstractmethod + def dead_error(self) -> BaseException: + ... + + @abstractmethod + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request.""" + ... 
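
`EngineClient.generate` is an async generator that yields progressively updated `RequestOutput` objects until `finished` is set, which is what both the demo API server and the OpenAI-compatible server iterate over to stream tokens. A hedged sketch of how a caller typically consumes an implementation of this protocol (`AsyncLLMEngine`); the model path and sampling values are illustrative.

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


async def main() -> None:
    # Model path is illustrative; any locally available model works.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="/model"))

    params = SamplingParams(temperature=0.7, max_tokens=64)
    request_id = random_uuid()

    # generate() yields incrementally longer outputs for the same request.
    async for request_output in engine.generate("San Francisco is", params,
                                                request_id):
        if request_output.finished:
            print(request_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```
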
+ + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output + + preprocessor = await self.get_input_preprocessor() + tokenizer_group = preprocessor.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async() + + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) + + prompt_token_ids = processed_inputs["prompt_token_ids"] + prompt_text = processed_inputs.get("prompt") + multi_modal_data = processed_inputs.get("multi_modal_data") + mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") + + tokenized_length = len(prompt_token_ids) + + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, length_penalty) + + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) + all_beams = [ + BeamSearchSequence(tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs) + ] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs) + for beam in all_beams + ] + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, individual_prompt in enumerate(prompts_batch): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, beam_search_params, + request_id_item))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append( + BeamSearchSequence( + tokens=current_beam.tokens + + [token_id] if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + finish_reason="stop", + stop_reason=tokenizer.eos_token_id)) + else: + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs)) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos): + # Skip the eos token in the text. 
+ tokens = beam.tokens[tokenized_length:-1] + else: + tokens = beam.tokens[tokenized_length:] + beam.text = tokenizer.decode(tokens) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput(text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason if + beam.finish_reason is not None else "length", + stop_reason=beam.stop_reason) + for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None) + + yield beam_search_output + + @abstractmethod + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model.""" + ... + + @abstractmethod + async def abort(self, request_id: str) -> None: + """Abort a request. + + Args: + request_id: The unique id of the request. + """ + ... + + @abstractmethod + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_input_preprocessor(self) -> InputPreprocessor: + """Get the input processor of the vLLM engine.""" + ... + + @abstractmethod + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + """Get the appropriate tokenizer for the request""" + ... + + @abstractmethod + async def is_tracing_enabled(self) -> bool: + ... + + @abstractmethod + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + ... + + @abstractmethod + async def check_health(self) -> None: + """Raise if unhealthy""" + ... + + @abstractmethod + async def start_profile(self) -> None: + """Start profiling the engine""" + ... + + @abstractmethod + async def stop_profile(self) -> None: + """Start profiling the engine""" + ... + + @abstractmethod + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + """Reset the prefix cache""" + ... + + @abstractmethod + async def sleep(self, level: int = 1) -> None: + """Sleep the engine""" + ... + + @abstractmethod + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + """Wake up the engine""" + ... + + @abstractmethod + async def is_sleeping(self) -> bool: + """Check whether the engine is sleeping""" + ... + + @abstractmethod + async def add_lora(self, lora_request: LoRARequest) -> None: + """Load a new LoRA adapter into the engine for future requests.""" + ... diff --git a/vllm/entrypoints/__init__.py b/vllm/entrypoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py new file mode 100644 index 0000000..c81ff95 --- /dev/null +++ b/vllm/entrypoints/api_server.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. 
+We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. +""" +import asyncio +import json +import ssl +from argparse import Namespace +from collections.abc import AsyncGenerator +from typing import Any, Optional + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("vllm.entrypoints.api_server") + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\n").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) + + assert final_output is not None + prompt = final_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +def build_app(args: Namespace) -> FastAPI: + global app + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) + + return app + + +async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + set_ulimit() + + app = await init_app(args, llm_engine) + assert engine is not 
None + + shutdown_task = await serve_http( + app, + sock=None, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=parser.check_port, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--ssl-ca-certs", + type=str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + asyncio.run(run_server(args)) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py new file mode 100644 index 0000000..e48ffae --- /dev/null +++ b/vllm/entrypoints/chat_utils.py @@ -0,0 +1,1202 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import json +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from collections.abc import Awaitable, Iterable +from functools import cache, lru_cache, partial +from pathlib import Path +from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, + cast) + +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils +# yapf conflicts with isort for this block +# yapf: disable +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) +from openai.types.chat.chat_completion_content_part_input_audio_param import ( + InputAudio) +# yapf: enable +# pydantic needs the TypedDict from typing_extensions +from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin) +from typing_extensions import Required, TypeAlias, TypedDict + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.utils import MediaConnector +from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class AudioURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the audio or a data URL with base64 encoded audio data. 
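
The demo `api_server.py` above exposes a bare `/generate` route (separate from the OpenAI-compatible server) that accepts a `prompt`, an optional `stream` flag, and any remaining `SamplingParams` fields. A quick smoke test against a server started on port 8000 might look like this; the prompt and parameter values are only examples.

```bash
curl http://localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "San Francisco is a",
    "stream": false,
    "temperature": 0.7,
    "max_tokens": 64
  }'
```
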
+ """ + + +class ChatCompletionContentPartAudioParam(TypedDict, total=False): + audio_url: Required[AudioURL] + + type: Required[Literal["audio_url"]] + """The type of the content part.""" + + +class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): + image_embeds: Required[Union[str, dict[str, str]]] + """ + The image embeddings. It can be either: + - A single base64 string. + - A dictionary where each value is a base64 string. + """ + type: Required[Literal["image_embeds"]] + """The type of the content part.""" + + +class VideoURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the video or a data URL with base64 encoded video data. + """ + + +class ChatCompletionContentPartVideoParam(TypedDict, total=False): + video_url: Required[VideoURL] + + type: Required[Literal["video_url"]] + """The type of the content part.""" + + +class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain image_url. + This is supported by OpenAI API, although it is not documented. + + Example: + { + "image_url": "https://example.com/image.jpg" + } + """ + image_url: Required[str] + + +class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain audio_url. + + Example: + { + "audio_url": "https://example.com/audio.mp3" + } + """ + audio_url: Required[str] + + +class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain audio_url. + + Example: + { + "video_url": "https://example.com/video.mp4" + } + """ + video_url: Required[str] + + +ChatCompletionContentPartParam: TypeAlias = Union[ + OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + ChatCompletionContentPartInputAudioParam, + ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentSimpleImageParam, + ChatCompletionContentPartImageEmbedsParam, + CustomChatCompletionContentSimpleAudioParam, + CustomChatCompletionContentSimpleVideoParam, str] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, list[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. 
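
These TypedDicts extend the OpenAI content-part schema with audio, video, and image-embedding parts, so a single chat message can mix modalities. At runtime they are plain dictionaries; a sketch of a `messages` payload using the `audio_url` part defined above (the URL is a placeholder):

```python
# A chat message mixing a text part with the "audio_url" content part above.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please transcribe this clip."},
            {
                "type": "audio_url",
                "audio_url": {"url": "https://example.com/audio.mp3"},
            },
        ],
    },
]
```
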
+ """ + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + +# TODO: Make fields ReadOnly once mypy supports it +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Union[Optional[str], list[dict[str, str]]] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +# Passed in by user +ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] + +# Used internally +_ChatTemplateContentFormat = Literal["string", "openai"] + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return (_is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: Optional[str] = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return (node.node is not None + and _is_var_or_elems_access(node.node, varname, key)) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if (isinstance(node, jinja2.nodes.Getitem) + and isinstance(node.arg, jinja2.nodes.Slice)): + return _is_var_or_elems_access(node.node, varname, key) + + # yapf: disable + return ( + _is_attr_access(node, varname, key) if key + else _is_var_access(node, varname) + ) # yapf: enable + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname + for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, 
jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +def _detect_content_format( + chat_template: str, + *, + default: _ChatTemplateContentFormat, +) -> _ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def resolve_mistral_chat_template( + chat_template: Optional[str], + **kwargs: Any, +) -> Optional[str]: + if chat_template is not None: + logger.warning_once( + "'chat_template' cannot be overridden for mistral tokenizer.") + if "add_generation_prompt" in kwargs: + logger.warning_once( + "'add_generation_prompt' is not supported for mistral tokenizer, " + "so it will be ignored.") + if "continue_final_message" in kwargs: + logger.warning_once( + "'continue_final_message' is not supported for mistral tokenizer, " + "so it will be ignored.") + return None + +def resolve_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + *, + trust_remote_code: bool, +) -> Optional[str]: + # 1st priority: The given chat template + if chat_template is not None: + return chat_template + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin), + trust_remote_code=trust_remote_code, + ) + if isinstance(processor, ProcessorMixin) and \ + processor.chat_template is not None: + return processor.chat_template + except Exception: + logger.debug("Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, exc_info=True) + + # 3rd priority: AutoTokenizer chat template + try: + return tokenizer.get_chat_template(chat_template, tools=tools) + except Exception: + logger.debug("Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, exc_info=True) + + return None + + +def _resolve_chat_template_content_format( + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, + *, + trust_remote_code: bool, +) -> _ChatTemplateContentFormat: + if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + hf_chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=chat_template, + trust_remote_code=trust_remote_code, + tools=tools, + ) + else: + 
hf_chat_template = None + + jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True)) + + detected_format = ("string" if jinja_text is None else + _detect_content_format(jinja_text, default="string")) + + return detected_format if given_format == "auto" else given_format + + +@lru_cache +def _log_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + detected_format: ChatTemplateContentFormatOption, +): + logger.info( + "Detected the chat template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + +def resolve_chat_template_content_format( + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, + *, + trust_remote_code: bool = False, +) -> _ChatTemplateContentFormat: + detected_format = _resolve_chat_template_content_format( + chat_template, + tools, + given_format, + tokenizer, + trust_remote_code=trust_remote_code, + ) + + _log_chat_template_content_format( + chat_template, + given_format=given_format, + detected_format=detected_format, + ) + + return detected_format + + +ModalityStr = Literal["image", "audio", "video", "image_embeds"] +_T = TypeVar("_T") + + +class BaseMultiModalItemTracker(ABC, Generic[_T]): + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. 
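
The content-format detection above works by parsing the chat template into a Jinja2 AST and looking for a loop over `message['content']`; if one exists, the template expects OpenAI-style content parts rather than a plain string. The following minimal illustration of that AST check is independent of the helper functions above and uses only the public `jinja2` API; the two templates are made up.

```python
import jinja2
from jinja2 import nodes

env = jinja2.Environment()

string_tmpl = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
openai_tmpl = ("{% for m in messages %}{% for part in m['content'] %}"
               "{{ part['text'] }}{% endfor %}{% endfor %}")


def loops_over_content(template_src: str) -> bool:
    ast = env.parse(template_src)
    # Look for any for-loop whose iterable is a 'content' subscript/attribute.
    for loop in ast.find_all(nodes.For):
        it = loop.iter
        if (isinstance(it, nodes.Getitem) and isinstance(it.arg, nodes.Const)
                and it.arg.value == "content"):
            return True
        if isinstance(it, nodes.Getattr) and it.attr == "content":
            return True
    return False


print(loops_over_content(string_tmpl))  # False -> "string" format
print(loops_over_content(openai_tmpl))  # True  -> "openai" format
```
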
+ """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + super().__init__() + + self._model_config = model_config + self._tokenizer = tokenizer + self._allowed_items = (model_config.multimodal_config.limit_per_prompt + if model_config.multimodal_config else {}) + + self._items_by_modality = defaultdict[str, list[_T]](list) + + @property + def model_config(self) -> ModelConfig: + return self._model_config + + @property + def allowed_local_media_path(self): + return self._model_config.allowed_local_media_path + + @staticmethod + @cache + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: + return tokenizer.decode(token_index) + + def _placeholder_str(self, modality: ModalityStr, + current_count: int) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + hf_config = self._model_config.hf_config + model_type = hf_config.model_type + + if modality in ("image", "image_embeds"): + if model_type == "chatglm": + return "<|begin_of_image|><|endoftext|><|end_of_image|>" + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return f"<|image_{current_count}|>" + if model_type == "phi4mm": + return "<|endoftext10|>" # 200010 (see vocab.json in hf model) + if model_type in ("minicpmo", "minicpmv"): + return "(./)" + if model_type in ("blip-2", "fuyu", "paligemma", "pixtral", + "mistral3"): + # These models do not use image tokens in the prompt + return None + if model_type == "qwen": + return f"Picture {current_count}: " + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.image_token_index) + if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2", + "internvl_chat", "skywork_chat", "NVLM_D", + "h2ovl_chat"): + return "" + if model_type in ("mllama", "llama4"): + return "<|image|>" + if model_type in ("qwen2_vl", "qwen2_5_vl"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "molmo": + return "" + if model_type == "idefics3": + return "" + if model_type == "aria": + return "<|fim_prefix|><|img|><|fim_suffix|>" + if model_type == "gemma3": + return "" + + raise TypeError(f"Unknown {modality} model type: {model_type}") + elif modality == "audio": + if model_type == "ultravox": + return "<|audio|>" + if model_type == "phi4mm": + return "<|endoftext11|>" # 200011 (see vocab.json in hf model) + if model_type == "qwen2_audio": + return (f"Audio {current_count}: " + f"<|audio_bos|><|AUDIO|><|audio_eos|>") + if model_type == "minicpmo": + return "()" + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "video": + if model_type in ("qwen2_vl", "qwen2_5_vl"): + return "<|vision_start|><|video_pad|><|vision_end|>" + if model_type in ("minicpmo", "minicpmv"): + return "()" + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.video_token_index) + raise TypeError(f"Unknown {modality} model type: {model_type}") + else: + raise TypeError(f"Unknown modality: {modality}") + + def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + """ + Add a multi-modal item to the current prompt and returns the + placeholder string to use, if any. 
+ """ + allowed_count = self._allowed_items.get(modality, 1) + current_count = len(self._items_by_modality[modality]) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request.") + + self._items_by_modality[modality].append(item) + + return self._placeholder_str(modality, current_count) + + @abstractmethod + def create_parser(self) -> "BaseMultiModalContentParser": + raise NotImplementedError + + +class MultiModalItemTracker(BaseMultiModalItemTracker[object]): + + def all_mm_data(self) -> Optional[MultiModalDataDict]: + if not self._items_by_modality: + return None + mm_inputs = {} + items_by_modality = dict(self._items_by_modality) + if "image" in items_by_modality and "image_embeds" in items_by_modality: + raise ValueError(\ + "Mixing raw image and embedding inputs is not allowed") + + if "image_embeds" in items_by_modality: + image_embeds_lst = items_by_modality["image_embeds"] + if len(image_embeds_lst) > 1: + raise ValueError(\ + "Only one message can have {'type': 'image_embeds'}") + mm_inputs["image"] = image_embeds_lst[0] + if "image" in items_by_modality: + mm_inputs["image"] = items_by_modality["image"] # A list of images + if "audio" in items_by_modality: + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + if "video" in items_by_modality: + mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs + + def create_parser(self) -> "BaseMultiModalContentParser": + return MultiModalContentParser(self) + + +class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): + + async def all_mm_data(self) -> Optional[MultiModalDataDict]: + if not self._items_by_modality: + return None + mm_inputs = {} + items_by_modality = { + modality: await asyncio.gather(*items) + for modality, items in self._items_by_modality.items() + } + + if "image" in items_by_modality and "image_embeds" in items_by_modality: + raise ValueError( + "Mixing raw image and embedding inputs is not allowed") + + if "image_embeds" in items_by_modality: + image_embeds_lst = items_by_modality["image_embeds"] + if len(image_embeds_lst) > 1: + raise ValueError( + "Only one message can have {'type': 'image_embeds'}") + mm_inputs["image"] = image_embeds_lst[0] + if "image" in items_by_modality: + mm_inputs["image"] = items_by_modality["image"] # A list of images + if "audio" in items_by_modality: + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + if "video" in items_by_modality: + mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs + + def create_parser(self) -> "BaseMultiModalContentParser": + return AsyncMultiModalContentParser(self) + + +class BaseMultiModalContentParser(ABC): + + def __init__(self) -> None: + super().__init__() + + # multimodal placeholder_string : count + self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0) + + def _add_placeholder(self, placeholder: Optional[str]): + if placeholder: + self._placeholder_counts[placeholder] += 1 + + def mm_placeholder_counts(self) -> dict[str, int]: + return dict(self._placeholder_counts) + + @abstractmethod + def parse_image(self, image_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_image_embeds(self, + image_embeds: Union[str, dict[str, str]]) -> None: + raise NotImplementedError + + @abstractmethod + def parse_audio(self, audio_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_input_audio(self, 
input_audio: InputAudio) -> None: + raise NotImplementedError + + @abstractmethod + def parse_video(self, video_url: str) -> None: + raise NotImplementedError + + +class MultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: MultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) + + def parse_image(self, image_url: str) -> None: + image = self._connector.fetch_image(image_url) + + placeholder = self._tracker.add("image", image) + self._add_placeholder(placeholder) + + def parse_image_embeds(self, + image_embeds: Union[str, dict[str, str]]) -> None: + if isinstance(image_embeds, dict): + embeds = { + k: self._connector.fetch_image_embedding(v) + for k, v in image_embeds.items() + } + placeholder = self._tracker.add("image_embeds", embeds) + + if isinstance(image_embeds, str): + embedding = self._connector.fetch_image_embedding(image_embeds) + placeholder = self._tracker.add("image_embeds", embedding) + + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio = self._connector.fetch_audio(audio_url) + + placeholder = self._tracker.add("audio", audio) + self._add_placeholder(placeholder) + + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + + return self.parse_audio(audio_url) + + def parse_video(self, video_url: str) -> None: + video = self._connector.fetch_video(video_url) + + placeholder = self._tracker.add("video", video) + self._add_placeholder(placeholder) + + +class AsyncMultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) + + def parse_image(self, image_url: str) -> None: + image_coro = self._connector.fetch_image_async(image_url) + + placeholder = self._tracker.add("image", image_coro) + self._add_placeholder(placeholder) + + def parse_image_embeds(self, + image_embeds: Union[str, dict[str, str]]) -> None: + future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() + + if isinstance(image_embeds, dict): + embeds = { + k: self._connector.fetch_image_embedding(v) + for k, v in image_embeds.items() + } + future.set_result(embeds) + + if isinstance(image_embeds, str): + embedding = self._connector.\ + fetch_image_embedding(image_embeds) + future.set_result(embedding) + + placeholder = self._tracker.add("image_embeds", future) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio_coro = self._connector.fetch_audio_async(audio_url) + + placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder(placeholder) + + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + + return self.parse_audio(audio_url) + + def parse_video(self, video_url: str) -> None: + video = self._connector.fetch_video_async(video_url) + + placeholder = self._tracker.add("video", video) + self._add_placeholder(placeholder) + + +def validate_chat_template(chat_template: Optional[Union[Path, str]]): + """Raises if the 
provided chat template appears invalid.""" + if chat_template is None: + return + + elif isinstance(chat_template, Path) and not chat_template.exists(): + raise FileNotFoundError( + "the supplied chat template path doesn't exist") + + elif isinstance(chat_template, str): + JINJA_CHARS = "{}\n" + if not any(c in chat_template + for c in JINJA_CHARS) and not Path(chat_template).exists(): + raise ValueError( + f"The supplied chat template string ({chat_template}) " + f"appears path-like, but doesn't exist!") + + else: + raise TypeError( + f"{type(chat_template)} is not a valid chat template type") + + +def _load_chat_template( + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: + if chat_template is None: + return None + + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly " + "from its value") + + return chat_template + + try: + with open(chat_template) as f: + return f.read() + except OSError as e: + if isinstance(chat_template, Path): + raise + + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + return _load_chat_template(chat_template, is_literal=True) + + +_cached_load_chat_template = lru_cache(_load_chat_template) + + +def load_chat_template( + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: + return _cached_load_chat_template(chat_template, is_literal=is_literal) + + +# TODO: Let user specify how to insert multimodal tokens into prompt +# (similar to chat template) +def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], + text_prompt: str) -> str: + """Combine multimodal prompts for a multimodal language model.""" + + # Look through the text prompt to check for missing placeholders + missing_placeholders: list[str] = [] + for placeholder in placeholder_counts: + + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") + + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. + return "\n".join(missing_placeholders + [text_prompt]) + + +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) + +_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio] + +# Define a mapping from part types to their corresponding parsing functions. 
+MM_PARSER_MAP: dict[ + str, + Callable[[ChatCompletionContentPartParam], _ContentPart], +] = { + "text": + lambda part: _TextParser(part).get("text", ""), + "image_url": + lambda part: _ImageParser(part).get("image_url", {}).get("url", ""), + "image_embeds": + lambda part: _ImageEmbedsParser(part).get("image_embeds", {}), + "audio_url": + lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""), + "input_audio": + lambda part: _InputAudioParser(part).get("input_audio", {}), + "refusal": + lambda part: _RefusalParser(part).get("refusal", ""), + "video_url": + lambda part: _VideoParser(part).get("video_url", {}).get("url", ""), +} + + +def _parse_chat_message_content_mm_part( + part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: + """ + Parses a given multi-modal content part based on its type. + + Args: + part: A dict containing the content part, with a potential 'type' field. + + Returns: + A tuple (part_type, content) where: + - part_type: Type of the part (e.g., 'text', 'image_url'). + - content: Parsed content (e.g., text, image URL). + + Raises: + ValueError: If the 'type' field is missing and no direct URL is found. + """ + assert isinstance( + part, dict) # This is needed to avoid mypy errors: part.get() from str + part_type = part.get("type", None) + + if isinstance(part_type, str) and part_type in MM_PARSER_MAP: + content = MM_PARSER_MAP[part_type](part) + + # Special case for 'image_url.detail' + # We only support 'auto', which is the default + if part_type == "image_url" and part.get("detail", "auto") != "auto": + logger.warning("'image_url.detail' is currently not supported " + "and will be ignored.") + + return part_type, content + + # Handle missing 'type' but provided direct URL fields. + # 'type' is required field by pydantic + if part_type is None: + if part.get("image_url") is not None: + image_params = cast(CustomChatCompletionContentSimpleImageParam, + part) + return "image_url", image_params.get("image_url", "") + if part.get("audio_url") is not None: + audio_params = cast(CustomChatCompletionContentSimpleAudioParam, + part) + return "audio_url", audio_params.get("audio_url", "") + if part.get("input_audio") is not None: + input_audio_params = cast(dict[str, str], part) + return "input_audio", input_audio_params + if part.get("video_url") is not None: + video_params = cast(CustomChatCompletionContentSimpleVideoParam, + part) + return "video_url", video_params.get("video_url", "") + # Raise an error if no 'type' or direct URL is found. 
+ raise ValueError("Missing 'type' field in multimodal part.") + + if not isinstance(part_type, str): + raise ValueError("Invalid 'type' field in multimodal part.") + return part_type, "unknown part_type content" + + +VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", + "image_embeds", + "audio_url", "input_audio", "video_url") + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + mm_tracker: BaseMultiModalItemTracker, + *, + wrap_dicts: bool, +) -> list[ConversationMessage]: + content = list[_ContentPart]() + + mm_parser = mm_tracker.create_parser() + + for part in parts: + parse_res = _parse_chat_message_content_part( + part, + mm_parser, + wrap_dicts=wrap_dicts, + ) + if parse_res: + content.append(parse_res) + + if wrap_dicts: + # Parsing wraps images and texts as interleaved dictionaries + return [ConversationMessage(role=role, + content=content)] # type: ignore + texts = cast(list[str], content) + text_prompt = "\n".join(texts) + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] + + +def _parse_chat_message_content_part( + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + *, + wrap_dicts: bool, +) -> Optional[_ContentPart]: + """Parses a single part of a conversation. If wrap_dicts is True, + structured dictionary pieces for texts and images will be + wrapped in dictionaries, i.e., {"type": "text", "text", ...} and + {"type": "image"}, respectively. Otherwise multimodal data will be + handled by mm_parser, and texts will be returned as strings to be joined + with multimodal placeholders. 
+ """ + if isinstance(part, str): # Handle plain text parts + return part + + # Handle structured dictionary parts + part_type, content = _parse_chat_message_content_mm_part(part) + + # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but + # content is empty, log a warning and skip + if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: + logger.warning( + "Skipping multimodal part (type: '%s') " + "with empty / unparsable content.", part_type) + return None + + if part_type in ("text", "refusal"): + str_content = cast(str, content) + if wrap_dicts: + return {'type': 'text', 'text': str_content} + else: + return str_content + + if part_type == "image_url": + str_content = cast(str, content) + mm_parser.parse_image(str_content) + return {'type': 'image'} if wrap_dicts else None + if part_type == "image_embeds": + content = cast(Union[str, dict[str, str]], content) + mm_parser.parse_image_embeds(content) + return {'type': 'image'} if wrap_dicts else None + if part_type == "audio_url": + str_content = cast(str, content) + mm_parser.parse_audio(str_content) + return {'type': 'audio'} if wrap_dicts else None + + if part_type == "input_audio": + dict_content = cast(InputAudio, content) + mm_parser.parse_input_audio(dict_content) + return {'type': 'audio'} if wrap_dicts else None + + if part_type == "video_url": + str_content = cast(str, content) + mm_parser.parse_video(str_content) + return {'type': 'video'} if wrap_dicts else None + + raise NotImplementedError(f"Unknown part type: {part_type}") + + +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) + + +def _parse_chat_message_content( + message: ChatCompletionMessageParam, + mm_tracker: BaseMultiModalItemTracker, + content_format: _ChatTemplateContentFormat, +) -> list[ConversationMessage]: + role = message["role"] + content = message.get("content") + + if content is None: + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + mm_tracker, + wrap_dicts=(content_format == "openai"), + ) + + for result_msg in result: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + + if "tool_calls" in parsed_msg: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + + +def _postprocess_messages(messages: list[ConversationMessage]) -> None: + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in messages: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for item in message["tool_calls"]: + item["function"]["arguments"] = json.loads( + item["function"]["arguments"]) + + +def parse_chat_messages( + messages: list[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, + 
content_format: _ChatTemplateContentFormat, +) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: + conversation: list[ConversationMessage] = [] + mm_tracker = MultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + content_format, + ) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +def parse_chat_messages_futures( + messages: list[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, +) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: + conversation: list[ConversationMessage] = [] + mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + content_format, + ) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +def apply_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + conversation: list[ConversationMessage], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + *, + trust_remote_code: bool = False, + tokenize: bool = False, # Different from HF's default + **kwargs: Any, +) -> str: + hf_chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + trust_remote_code=trust_remote_code, + ) + + if hf_chat_template is None: + raise ValueError( + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one.") + + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + tools=tools, # type: ignore[arg-type] + chat_template=hf_chat_template, + tokenize=tokenize, + **kwargs, + ) + + +def apply_mistral_chat_template( + tokenizer: MistralTokenizer, + messages: list[ChatCompletionMessageParam], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + **kwargs: Any, +) -> list[int]: + # The return value of resolve_mistral_chat_template is always None, + # and we won't use it. + resolve_mistral_chat_template( + chat_template=chat_template, + **kwargs, + ) + + return tokenizer.apply_chat_template( + messages=messages, + tools=tools, + **kwargs, + ) diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/cli/benchmark/__init__.py b/vllm/entrypoints/cli/benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/cli/benchmark/base.py b/vllm/entrypoints/cli/benchmark/base.py new file mode 100644 index 0000000..c41b2c5 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/base.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.utils import FlexibleArgumentParser + + +class BenchmarkSubcommandBase(CLISubcommand): + """ The base class of subcommands for vllm bench. 
""" + + @property + def help(self) -> str: + """The help message of the subcommand.""" + raise NotImplementedError + + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + """Add the CLI arguments to the parser.""" + raise NotImplementedError + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Run the benchmark. + + Args: + args: The arguments to the command. + """ + raise NotImplementedError + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + parser = subparsers.add_parser( + self.name, + help=self.help, + usage=f"vllm bench {self.name} [options]") + self.add_cli_args(parser) + return parser diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py new file mode 100644 index 0000000..7583540 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +import vllm.entrypoints.cli.benchmark.serve +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.utils import FlexibleArgumentParser + +# TODO: Add the rest of the benchmark subcommands here, +# e.g., throughput, latency, etc. +BENCHMARK_CMD_MODULES = [ + vllm.entrypoints.cli.benchmark.serve, +] + + +class BenchmarkSubcommand(CLISubcommand): + """ The `bench` subcommand for the vLLM CLI. """ + + def __init__(self): + self.name = "bench" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + args.dispatch_function(args) + + def validate(self, args: argparse.Namespace) -> None: + if args.bench_type in self.cmds: + self.cmds[args.bench_type].validate(args) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + bench_parser = subparsers.add_parser( + "bench", + help="vLLM bench subcommand.", + usage="vllm bench [options]") + bench_subparsers = bench_parser.add_subparsers(required=True, + dest="bench_type") + self.cmds = {} + for cmd_module in BENCHMARK_CMD_MODULES: + new_cmds = cmd_module.cmd_init() + for cmd in new_cmds: + cmd.subparser_init(bench_subparsers).set_defaults( + dispatch_function=cmd.cmd) + self.cmds[cmd.name] = cmd + return bench_parser + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkSubcommand()] diff --git a/vllm/entrypoints/cli/benchmark/serve.py b/vllm/entrypoints/cli/benchmark/serve.py new file mode 100644 index 0000000..d5a8589 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/serve.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.benchmarks.serve import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkServingSubcommand(BenchmarkSubcommandBase): + """ The `serve` subcommand for vllm bench. """ + + def __init__(self): + self.name = "serve" + super().__init__() + + @property + def help(self) -> str: + return "Benchmark the online serving throughput." + + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkServingSubcommand()] diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py new file mode 100644 index 0000000..aa54bd6 --- /dev/null +++ b/vllm/entrypoints/cli/main.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The CLI entrypoint to vLLM. 
+import signal +import sys + +import vllm.entrypoints.cli.benchmark.main +import vllm.entrypoints.cli.openai +import vllm.entrypoints.cli.serve +import vllm.version +from vllm.entrypoints.utils import cli_env_setup +from vllm.utils import FlexibleArgumentParser + +CMD_MODULES = [ + vllm.entrypoints.cli.openai, + vllm.entrypoints.cli.serve, + vllm.entrypoints.cli.benchmark.main, +] + + +def register_signal_handlers(): + + def signal_handler(sig, frame): + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def main(): + cli_env_setup() + + parser = FlexibleArgumentParser(description="vLLM CLI") + parser.add_argument('-v', + '--version', + action='version', + version=vllm.version.__version__) + subparsers = parser.add_subparsers(required=False, dest="subparser") + cmds = {} + for cmd_module in CMD_MODULES: + new_cmds = cmd_module.cmd_init() + for cmd in new_cmds: + cmd.subparser_init(subparsers).set_defaults( + dispatch_function=cmd.cmd) + cmds[cmd.name] = cmd + args = parser.parse_args() + if args.subparser in cmds: + cmds[args.subparser].validate(args) + + if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py new file mode 100644 index 0000000..21a7d48 --- /dev/null +++ b/vllm/entrypoints/cli/openai.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# Commands that act as an interactive OpenAI API client + +import argparse +import os +import signal +import sys +from typing import Optional + +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.utils import FlexibleArgumentParser + + +def _register_signal_handlers(): + + def signal_handler(sig, frame): + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]: + _register_signal_handlers() + + base_url = args.url + api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY") + openai_client = OpenAI(api_key=api_key, base_url=base_url) + + if args.model_name: + model_name = args.model_name + else: + available_models = openai_client.models.list() + model_name = available_models.data[0].id + + print(f"Using model: {model_name}") + + return model_name, openai_client + + +def chat(system_prompt: Optional[str], model_name: str, + client: OpenAI) -> None: + conversation: list[ChatCompletionMessageParam] = [] + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + print("Please enter a message for the chat model:") + while True: + try: + input_message = input("> ") + except EOFError: + return + conversation.append({"role": "user", "content": input_message}) + + chat_completion = client.chat.completions.create(model=model_name, + messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) # type: ignore + print(output) + + +def _add_query_options( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + "--url", + type=str, + default="http://localhost:8000/v1", + help="url of the running OpenAI-Compatible RESTful API server") + parser.add_argument( + "--model-name", + type=str, + default=None, + help=("The model name 
used in prompt completion, default to " + "the first model in list models API call.")) + parser.add_argument( + "--api-key", + type=str, + default=None, + help=( + "API key for OpenAI services. If provided, this api key " + "will overwrite the api key obtained through environment variables." + )) + return parser + + +class ChatCommand(CLISubcommand): + """The `chat` subcommand for the vLLM CLI. """ + + def __init__(self): + self.name = "chat" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + model_name, client = _interactive_cli(args) + system_prompt = args.system_prompt + conversation: list[ChatCompletionMessageParam] = [] + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + print("Please enter a message for the chat model:") + while True: + try: + input_message = input("> ") + except EOFError: + return + conversation.append({"role": "user", "content": input_message}) + + chat_completion = client.chat.completions.create( + model=model_name, messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) # type: ignore + print(output) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + chat_parser = subparsers.add_parser( + "chat", + help="Generate chat completions via the running API server", + usage="vllm chat [options]") + _add_query_options(chat_parser) + chat_parser.add_argument( + "--system-prompt", + type=str, + default=None, + help=("The system prompt to be added to the chat template, " + "used for models that support system prompts.")) + return chat_parser + + +class CompleteCommand(CLISubcommand): + """The `complete` subcommand for the vLLM CLI. """ + + def __init__(self): + self.name = "complete" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + model_name, client = _interactive_cli(args) + print("Please enter prompt to complete:") + while True: + input_prompt = input("> ") + completion = client.completions.create(model=model_name, + prompt=input_prompt) + output = completion.choices[0].text + print(output) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + complete_parser = subparsers.add_parser( + "complete", + help=("Generate text completions based on the given prompt " + "via the running API server"), + usage="vllm complete [options]") + _add_query_options(complete_parser) + return complete_parser + + +def cmd_init() -> list[CLISubcommand]: + return [ChatCommand(), CompleteCommand()] diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py new file mode 100644 index 0000000..e89ac4e --- /dev/null +++ b/vllm/entrypoints/cli/serve.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +import uvloop + +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +from vllm.utils import FlexibleArgumentParser + + +class ServeSubcommand(CLISubcommand): + """The `serve` subcommand for the vLLM CLI. 
""" + + def __init__(self): + self.name = "serve" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + # If model is specified in CLI (as positional arg), it takes precedence + if hasattr(args, 'model_tag') and args.model_tag is not None: + args.model = args.model_tag + + uvloop.run(run_server(args)) + + def validate(self, args: argparse.Namespace) -> None: + validate_parsed_serve_args(args) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + "serve", + help="Start the vLLM OpenAI Compatible API server", + usage="vllm serve [model_tag] [options]") + serve_parser.add_argument("model_tag", + type=str, + nargs='?', + help="The model tag to serve " + "(optional if specified in config)") + serve_parser.add_argument( + "--config", + type=str, + default='', + required=False, + help="Read CLI options from a config file." + "Must be a YAML with the following options:" + "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference" + ) + + return make_arg_parser(serve_parser) + + +def cmd_init() -> list[CLISubcommand]: + return [ServeSubcommand()] diff --git a/vllm/entrypoints/cli/types.py b/vllm/entrypoints/cli/types.py new file mode 100644 index 0000000..f739a68 --- /dev/null +++ b/vllm/entrypoints/cli/types.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from vllm.utils import FlexibleArgumentParser + + +class CLISubcommand: + """Base class for CLI argument handlers.""" + + name: str + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + raise NotImplementedError("Subclasses should implement this method") + + def validate(self, args: argparse.Namespace) -> None: + # No validation by default + pass + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + raise NotImplementedError("Subclasses should implement this method") diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py new file mode 100644 index 0000000..b09ee52 --- /dev/null +++ b/vllm/entrypoints/launcher.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import signal +import socket +from http import HTTPStatus +from typing import Any, Optional + +import uvicorn +from fastapi import FastAPI, Request, Response + +from vllm import envs +from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.entrypoints.ssl import SSLCertRefresher +from vllm.logger import init_logger +from vllm.utils import find_process_using_port + +logger = init_logger(__name__) + + +async def serve_http(app: FastAPI, + sock: Optional[socket.socket], + enable_ssl_refresh: bool = False, + **uvicorn_kwargs: Any): + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is None or path is None: + continue + + logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + config.load() + server = uvicorn.Server(config) + _add_shutdown_handlers(app, server) + + loop = asyncio.get_running_loop() + + server_task = loop.create_task( + server.serve(sockets=[sock] if sock else None)) + + ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher( + ssl_context=config.ssl, + key_path=config.ssl_keyfile, + cert_path=config.ssl_certfile, + 
ca_path=config.ssl_ca_certs) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + if ssl_cert_refresher: + ssl_cert_refresher.stop() + + async def dummy_shutdown() -> None: + pass + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + port = uvicorn_kwargs["port"] + process = find_process_using_port(port) + if process is not None: + logger.debug( + "port %s is used by process %s launched with command:\n%s", + port, process, " ".join(process.cmdline())) + logger.info("Shutting down FastAPI HTTP server.") + return server.shutdown() + + +def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: + """Adds handlers for fatal errors that should crash the server""" + + @app.exception_handler(RuntimeError) + async def runtime_error_handler(request: Request, __): + """On generic runtime error, check to see if the engine has died. + It probably has, in which case the server will no longer be able to + handle requests. Trigger a graceful shutdown with a SIGTERM.""" + engine = request.app.state.engine_client + if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored + and not engine.is_running): + logger.fatal("AsyncLLMEngine has failed, terminating server " + "process") + # See discussions here on shutting down a uvicorn server + # https://github.com/encode/uvicorn/discussions/1103 + # In this case we cannot await the server shutdown here because + # this handler must first return to close the connection for + # this request. + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(AsyncEngineDeadError) + async def async_engine_dead_handler(_, __): + """Kill the server if the async engine is already dead. It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("AsyncLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. 
It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("MQLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py new file mode 100644 index 0000000..f39b011 --- /dev/null +++ b/vllm/entrypoints/llm.py @@ -0,0 +1,1411 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +import warnings +from collections.abc import Sequence +from contextlib import contextmanager +from typing import Any, Callable, ClassVar, Optional, Union, cast, overload + +import cloudpickle +import torch.nn as nn +from tqdm import tqdm +from typing_extensions import TypeVar, deprecated + +from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, + BeamSearchSequence, get_beam_search_score) +from vllm.config import CompilationConfig +from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, + TaskOption) +from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages, + resolve_chat_template_content_format) +from vllm.entrypoints.score_utils import (_cosine_similarity, + _validate_score_input_lens) +from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt +from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest, LLMGuidedOptions) +from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, + PoolingRequestOutput, RequestOutput, + ScoringRequestOutput) +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + get_cached_tokenizer) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, + is_list_of) + +logger = init_logger(__name__) + +_R = TypeVar("_R", default=Any) + + +class LLM: + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. Expect valid prompt_token_ids and None for prompt + from the input. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. 
+ allowed_local_media_path: Allowing API requests to read local images + or videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Noting that `best_of` is only supported in V0. Otherwise, too small + values may cause out-of-memory (OOM) errors. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading + the model weights. This virtually increases the GPU memory space + you can use to hold the model weights, at the cost of CPU-GPU data + transfer for every forward pass. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. + disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig` + disable_async_output_proc: Disable async output processing. + This may result in lower performance. + hf_overrides: If a dictionary, contains arguments to be forwarded to the + HuggingFace config. If a callable, it is called to update the + HuggingFace config. + compilation_config: Either an integer or a dictionary. If it is an + integer, it is used as the level of compilation optimization. If it + is a dictionary, it can specify the full compilation configuration. + **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See + :ref:`engine-args`) + + Note: + This class is intended to be used for offline inference. For online + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. 
+ """ + + DEPRECATE_LEGACY: ClassVar[bool] = True + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + DEPRECATE_INIT_POSARGS: ClassVar[bool] = True + """ + A flag to toggle whether to deprecate positional arguments in + :meth:`LLM.__init__`. + """ + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + + @deprecate_args( + start_index=2, # Ignore self and model + is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS, + additional_message=( + "All positional arguments other than `model` will be " + "replaced with keyword arguments in an upcoming version."), + ) + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: Optional[int] = None, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + # After positional args are removed, move this right below `model` + task: TaskOption = "auto", + override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, dict[str, Any]]] = None, + **kwargs, + ) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if "worker_cls" in kwargs: + worker_cls = kwargs["worker_cls"] + # if the worker_cls is not qualified string name, + # we serialize it using cloudpickle to avoid pickling issues + if isinstance(worker_cls, type): + kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) + + if compilation_config is not None: + if isinstance(compilation_config, (int, dict)): + compilation_config_instance = CompilationConfig.from_cli( + str(compilation_config)) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = None + + engine_args = EngineArgs( + model=model, + task=task, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + hf_overrides=hf_overrides, + mm_processor_kwargs=mm_processor_kwargs, + override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, + **kwargs, + ) + + # Create the Engine (autoselects V0 vs V1) + self.llm_engine = LLMEngine.from_engine_args( + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + self.engine_class = type(self.llm_engine) + + self.request_counter = Counter() + self.default_sampling_params: Union[dict[str, Any], None] = None + + def get_tokenizer(self) -> AnyTokenizer: + return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + + def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: + tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) + + # While CachedTokenizer is dynamic, have no choice but + # compare class name. Misjudgment will arise from + # user-defined tokenizer started with 'Cached' + if tokenizer.__class__.__name__.startswith("Cached"): + tokenizer_group.tokenizer = tokenizer + else: + tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + + def get_default_sampling_params(self) -> SamplingParams: + if self.default_sampling_params is None: + self.default_sampling_params = ( + self.llm_engine.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + return SamplingParams.from_optional(**self.default_sampling_params) + return SamplingParams() + + @overload + def generate( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... 
+ + @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + prompt_token_ids: Optional[list[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: list[str], + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + prompt_token_ids: Optional[list[list[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + *, + prompt_token_ids: list[int], + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: Optional[list[str]] = None, + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + *, + prompt_token_ids: list[list[int]], + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[list[int], list[list[int]]], + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... 
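+
+    # A minimal usage sketch of the keyword-based `generate` API defined below
+    # (the model name is only an assumed example):
+    #
+    #     from vllm import LLM, SamplingParams
+    #     llm = LLM(model="facebook/opt-125m")
+    #     outputs = llm.generate(
+    #         ["The capital of France is"],
+    #         SamplingParams(temperature=0.0, max_tokens=16),
+    #     )
+    #     print(outputs[0].outputs[0].text)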
+ + @deprecate_kwargs( + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'prompts' parameter instead.", + ) + def generate( + self, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, list[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + priority: Optional[list[int]] = None, + ) -> list[RequestOutput]: + """Generates the completions for the input prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + sampling_params: The sampling parameters for text generation. If + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + priority: The priority of the requests, if any. + Only applicable when priority scheduling policy is enabled. + + Returns: + A list of ``RequestOutput`` objects containing the + generated completions in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. + """ + runner_type = self.llm_engine.model_config.runner_type + if runner_type not in ["generate", "transcription"]: + messages = [ + "LLM.generate() is only supported for (conditional) generation " + "models (XForCausalLM, XForConditionalGeneration).", + ] + + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "generate" in supported_runner_types: + messages.append( + "Your model supports the 'generate' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task generate`.") + + raise ValueError(" ".join(messages)) + + if prompt_token_ids is not None: + parsed_prompts = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, list[str]]], prompts), + prompt_token_ids=prompt_token_ids, + ) + else: + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) + + if isinstance(guided_options_request, dict): + if len(guided_options_request) > 1: + raise ValueError( + "You can only use one guided decoding but multiple is " + f"specified: {guided_options_request}") + guided_options_request = GuidedDecodingRequest( + **guided_options_request) + + if sampling_params is None: + # Use default sampling params. 
+ sampling_params = self.get_default_sampling_params() + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + guided_options=guided_options_request, + priority=priority) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, RequestOutput) + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + """ + Execute an RPC call on all workers. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. + + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. + timeout: Maximum time in seconds to wait for execution. Raises a + :exc:`TimeoutError` on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + + return self.llm_engine.collective_rpc(method, timeout, args, kwargs) + + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + """ + Run a function directly on the model inside each worker, + returning the result for each of them. + """ + executor = self.llm_engine.model_executor + return executor.apply_model(func) + + def beam_search( + self, + prompts: list[Union[TokensPrompt, TextPrompt]], + params: BeamSearchParams, + ) -> list[BeamSearchOutput]: + """ + Generate sequences using beam search. + + Args: + prompts: A list of prompts. Each prompt can be a string or a list + of token IDs. + params: The beam search parameters. + + TODO: how does beam search work together with length penalty, frequency + penalty, and stopping criteria, etc.? 
+ """ + + beam_width = params.beam_width + max_tokens = params.max_tokens + temperature = params.temperature + ignore_eos = params.ignore_eos + length_penalty = params.length_penalty + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, + tokenizer.eos_token_id, + length_penalty) + + tokenizer = self.get_tokenizer() + # generate 2 * beam_width candidates at each step + # following the huggingface transformers implementation + # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + instances: list[BeamSearchInstance] = [] + + for prompt in prompts: + if is_token_prompt(prompt): + prompt_tokens = prompt["prompt_token_ids"] + else: + prompt_tokens = tokenizer.encode(prompt["prompt"]) + instances.append(BeamSearchInstance(prompt_tokens)) + + for _ in range(max_tokens): + all_beams: list[BeamSearchSequence] = list( + sum((instance.beams for instance in instances), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances)) + instance_start_and_end: list[tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False) + + for (start, end), instance in zip(instance_start_and_end, + instances): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the max-model-len + # or abortion. we don't need to add it to the new beams. 
+ logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=sort_beams_key, + reverse=True) + instance.beams = sorted_beams[:beam_width] + + outputs = [] + for instance in instances: + instance.completed.extend(instance.beams) + sorted_completed = sorted(instance.completed, + key=sort_beams_key, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens) + outputs.append(BeamSearchOutput(sequences=best_beams)) + + return outputs + + def chat( + self, + messages: Union[list[ChatCompletionMessageParam], + list[list[ChatCompletionMessageParam]]], + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tools: Optional[list[dict[str, Any]]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ) -> list[RequestOutput]: + """ + Generate responses for a chat conversation. + + The chat conversation is converted into a text prompt using the + tokenizer and calls the :meth:`generate` method to generate the + responses. + + Multi-modal inputs can be passed in the same way you would pass them + to the OpenAI API. + + Args: + messages: A list of conversations or a single conversation. + + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. + + sampling_params: The sampling parameters for text generation. + If None, we use the default sampling parameters. When it + is a single value, it is applied to every prompt. When it + is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + chat_template: The template to use for structuring the chat. + If not provided, the model's default chat template will be used. + chat_template_content_format: The format to render message content. + + - "string" will render the content as a string. + Example: ``"Who are you?"`` + - "openai" will render the content as a list of dictionaries, + similar to OpenAI schema. + Example: ``[{"type": "text", "text": "Who are you?"}]`` + + add_generation_prompt: If True, adds a generation template + to each message. + continue_final_message: If True, continues the final message in + the conversation instead of starting a new one. Cannot be + ``True`` if ``add_generation_prompt`` is also ``True``. + mm_processor_kwargs: Multimodal processor kwarg overrides for this + chat request. Only used for offline requests. + + Returns: + A list of ``RequestOutput`` objects containing the generated + responses in the same order as the input messages. 
+ """ + list_of_messages: list[list[ChatCompletionMessageParam]] + + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is list[list[...]] + list_of_messages = cast(list[list[ChatCompletionMessageParam]], + messages) + else: + # messages is list[...] + list_of_messages = [ + cast(list[ChatCompletionMessageParam], messages) + ] + + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tools, + chat_template_content_format, + tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + + prompts: list[Union[TokensPrompt, TextPrompt]] = [] + + for msgs in list_of_messages: + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. + conversation, mm_data = parse_chat_messages( + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) + + prompt_data: Union[str, list[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt_data = apply_mistral_chat_template( + tokenizer, + messages=msgs, + chat_template=chat_template, + tools=tools, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + ) + else: + prompt_data = apply_hf_chat_template( + tokenizer, + trust_remote_code=model_config.trust_remote_code, + conversation=conversation, + chat_template=chat_template, + tools=tools, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + ) + + prompt: Union[TokensPrompt, TextPrompt] + if is_list_of(prompt_data, int): + prompt = TokensPrompt(prompt_token_ids=prompt_data) + else: + prompt = TextPrompt(prompt=prompt_data) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + + prompts.append(prompt) + + return self.generate( + prompts, + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + + @overload + def encode( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[list[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: list[str], + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[list[list[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... 
+ + @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: list[int], + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: Optional[list[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: list[list[int]], + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[list[int], list[list[int]]], + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @deprecate_kwargs( + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'prompts' parameter instead.", + ) + def encode( + self, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, list[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + """Apply pooling to the hidden states corresponding to the input + prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``PoolingRequestOutput`` objects containing the + pooled hidden states in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. 
+ """ + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "pooling": + messages = ["LLM.encode() is only supported for pooling models."] + + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "pooling" in supported_runner_types: + messages.append( + "Your model supports the 'pooling' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task embed`, " + "`--task classify`, `--task score` etc.") + + raise ValueError(" ".join(messages)) + + if prompt_token_ids is not None: + parsed_prompts = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, list[str]]], prompts), + prompt_token_ids=prompt_token_ids, + ) + else: + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) + + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) + + def embed( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[EmbeddingRequestOutput]: + """ + Generate an embedding vector for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``EmbeddingRequestOutput`` objects containing the + embedding vectors in the same order as the input prompts. + """ + if self.llm_engine.model_config.task != "embed": + raise ValueError( + "Embedding API is only enabled for `--task embed`") + + items = self.encode(prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [EmbeddingRequestOutput.from_base(item) for item in items] + + def classify( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ClassificationRequestOutput]: + """ + Generate class logits for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. 
+ + Returns: + A list of ``ClassificationRequestOutput`` objects containing the + embedding vectors in the same order as the input prompts. + """ + if self.llm_engine.model_config.task != "classify": + raise ValueError( + "Classification API is only enabled for `--task classify`") + + items = self.encode(prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [ClassificationRequestOutput.from_base(item) for item in items] + + def _embedding_score( + self, + tokenizer: AnyTokenizer, + text_1: list[Union[str, TextPrompt, TokensPrompt]], + text_2: list[Union[str, TextPrompt, TokensPrompt]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ScoringRequestOutput]: + + encoded_output: list[PoolingRequestOutput] = self.encode( + text_1 + text_2, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + encoded_output_1: list[PoolingRequestOutput] = encoded_output[ + 0:len(text_1)] + encoded_output_2: list[PoolingRequestOutput] = encoded_output[ + len(text_1):] + + if len(encoded_output_1) == 1: + encoded_output_1 = encoded_output_1 * len(encoded_output_2) + + scores: list[PoolingRequestOutput] = [] + + scores = _cosine_similarity(tokenizer=tokenizer, + embed_1=encoded_output_1, + embed_2=encoded_output_2) + + items = self.engine_class.validate_outputs(scores, + PoolingRequestOutput) + return [ScoringRequestOutput.from_base(item) for item in items] + + def _cross_encoding_score( + self, + tokenizer: AnyTokenizer, + text_1: list[str], + text_2: list[str], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ScoringRequestOutput]: + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "Score API is only enabled for `--task embed or score`") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] + + pooling_params = PoolingParams() + + tokenization_kwargs: dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + items = self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) + + return [ScoringRequestOutput.from_base(item) for item in items] + + def score( + self, + text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], + text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], + /, + *, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ScoringRequestOutput]: + """Generate 
similarity scores for all pairs ``<text,text_pair>``.
+
+        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
+        In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
+        times to pair with the ``text_2`` sentences.
+        The input pairs are used to build a list of prompts for the
+        cross encoder model. This class automatically batches the prompts,
+        considering the memory constraint. For the best performance, put all
+        of your texts into a single list and pass it to this method.
+
+        Args:
+            text_1: can be a single prompt or a list of prompts, in which
+                case it has to have the same length as the ``text_2`` list
+            text_2: The texts to pair with the query to form the input
+                to the LLM. See :class:`~vllm.inputs.PromptType` for
+                more details about the format of each prompts.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            prompt_adapter_request: Prompt Adapter request to use for
+                generation, if any.
+
+        Returns:
+            A list of ``ScoringRequestOutput`` objects containing the
+            generated scores in the same order as the input prompts.
+        """
+        runner_type = self.llm_engine.model_config.runner_type
+        if runner_type != "pooling":
+            messages = ["LLM.score() is only supported for pooling models."]
+
+            supported_runner_types = self.llm_engine.model_config \
+                .supported_runner_types
+            if "pooling" in supported_runner_types:
+                messages.append(
+                    "Your model supports the 'pooling' runner, but is "
+                    f"currently initialized for the '{runner_type}' runner. "
+                    "Please initialize vLLM using `--task embed`, "
+                    "`--task classify`, `--task score` etc.")
+
+            raise ValueError(" ".join(messages))
+
+        if self.llm_engine.model_config.task not in ("embed", "score"):
+            raise ValueError(
+                "Score API is only enabled for `--task embed or --task score`")
+
+        # the tokenizer for models such as
+        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
+        # lists of tokens to the `text` and `text_pair` kwargs
+        tokenizer = self.llm_engine.get_tokenizer()
+
+        def ensure_str(prompt: SingletonPrompt):
+            if isinstance(prompt, dict):
+                if "multi_modal_data" in prompt:
+                    raise ValueError("Multi-modal prompt is not "
+                                     "supported for scoring")
+                elif "prompt_token_ids" in prompt:
+                    prompt = tokenizer.decode(
+                        cast(TokensPrompt, prompt)["prompt_token_ids"])
+                elif "prompt" in prompt:
+                    prompt = cast(TextPrompt, prompt)["prompt"]
+            assert type(prompt) is str
+            return prompt
+
+        if isinstance(text_1, (str, dict)):
+            # Convert a single prompt to a list.
+            text_1 = [text_1]
+        input_text_1: list[str] = [ensure_str(t) for t in text_1]
+
+        if isinstance(text_2, (str, dict)):
+            # Convert a single prompt to a list.
+ text_2 = [text_2] + input_text_2: list[str] = [ensure_str(t) for t in text_2] + + _validate_score_input_lens(input_text_1, input_text_2) + + if self.llm_engine.model_config.is_cross_encoder: + return self._cross_encoding_score(tokenizer, input_text_1, + input_text_2, + truncate_prompt_tokens, use_tqdm, + lora_request, + prompt_adapter_request) + else: + return self._embedding_score( + tokenizer, + input_text_1, # type: ignore[arg-type] + input_text_2, # type: ignore[arg-type] + truncate_prompt_tokens, + use_tqdm, + lora_request, + prompt_adapter_request) + + def start_profile(self) -> None: + self.llm_engine.start_profile() + + def stop_profile(self) -> None: + self.llm_engine.stop_profile() + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.llm_engine.reset_prefix_cache(device) + + def sleep(self, level: int = 1): + """ + Put the engine to sleep. The engine should not process any requests. + The caller should guarantee that no requests are being processed + during the sleep period, before `wake_up` is called. + + Args: + level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. The content of kv cache + is forgotten. Level 1 sleep is good for sleeping and waking + up the engine to run the same model again. The model weights + are backed up in CPU memory. Please make sure there's enough + CPU memory to store the model weights. Level 2 sleep will + discard both the model weights and the kv cache. The content + of both the model weights and kv cache is forgotten. Level 2 + sleep is good for sleeping and waking up the engine to run a + different model or update the model, where previous model + weights are not needed. It reduces CPU memory pressure. + """ + self.reset_prefix_cache() + self.llm_engine.sleep(level=level) + + def wake_up(self, tags: Optional[list[str]] = None): + """ + Wake up the engine from sleep mode. See the :meth:`sleep` method + for more details. + + Args: + tags: An optional list of tags to reallocate the engine memory + for specific memory allocations. Values must be in + ("weights", "kv_cache",). If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + engine is used again. 
+ """ + self.llm_engine.wake_up(tags) + + # LEGACY + def _convert_v1_inputs( + self, + prompts: Optional[Union[str, list[str]]], + prompt_token_ids: Optional[Union[list[int], list[list[int]]]], + ): + # skip_tokenizer_init is now checked in engine + + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + + num_requests = None + if prompts is not None: + num_requests = len(prompts) + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + parsed_prompts: list[PromptType] = [] + for i in range(num_requests): + item: PromptType + + if prompts is not None: + item = TextPrompt(prompt=prompts[i]) + elif prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + parsed_prompts.append(item) + + return parsed_prompts + + def _validate_and_add_requests( + self, + prompts: Union[PromptType, Sequence[PromptType]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], + guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[list[int]] = None, + ) -> None: + if guided_options is not None: + warnings.warn( + "guided_options_request is deprecated, use " + "SamplingParams.guided_decoding instead", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(prompts, (str, dict)): + # Convert a single prompt to a list. + prompts = [prompts] + + num_requests = len(prompts) + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " + "must be the same.") + if isinstance(lora_request, + list) and len(lora_request) != num_requests: + raise ValueError("The lengths of prompts and lora_request " + "must be the same.") + + for sp in params if isinstance(params, list) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_params(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY + + # Add requests to the engine. 
+ for i, prompt in enumerate(prompts): + self._add_request( + prompt, + params[i] if isinstance(params, Sequence) else params, + lora_request=lora_request[i] if isinstance( + lora_request, Sequence) else lora_request, + prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, + ) + + def _add_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request( + request_id, + prompt, + params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + + def _add_guided_params( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingRequest] = None): + if guided_options is None: + return params + + if params.guided_decoding is not None: + raise ValueError("Cannot set both guided_options_request and " + "params.guided_decoding.") + + params.guided_decoding = GuidedDecodingParams( + json=guided_options.guided_json, + regex=guided_options.guided_regex, + choice=guided_options.guided_choice, + grammar=guided_options.guided_grammar, + json_object=guided_options.guided_json_object, + backend=guided_options.guided_decoding_backend, + whitespace_pattern=guided_options.guided_whitespace_pattern) + return params + + def _run_engine( + self, *, use_tqdm: bool + ) -> list[Union[RequestOutput, PoolingRequestOutput]]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), + ) + + # Run the engine. + outputs: list[Union[RequestOutput, PoolingRequestOutput]] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + n = len(output.outputs) + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) * n + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum( + len(stp.token_ids) for stp in output.outputs) + out_spd = (total_out_toks / + pbar.format_dict["elapsed"]) + pbar.postfix = ( + f"est. speed input: {in_spd:.2f} toks/s, " + f"output: {out_spd:.2f} toks/s") + pbar.update(n) + else: + pbar.update(1) + + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
+ return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py new file mode 100644 index 0000000..ea57591 --- /dev/null +++ b/vllm/entrypoints/logger.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Union + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams + +logger = init_logger(__name__) + + +class RequestLogger: + + def __init__(self, *, max_log_len: Optional[int]) -> None: + super().__init__() + + self.max_log_len = max_log_len + + def log_inputs( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[list[int]], + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + max_log_len = self.max_log_len + if max_log_len is not None: + if prompt is not None: + prompt = prompt[:max_log_len] + + if prompt_token_ids is not None: + prompt_token_ids = prompt_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s, prompt_adapter_request: %s.", request_id, + prompt, params, prompt_token_ids, lora_request, + prompt_adapter_request) diff --git a/vllm/entrypoints/openai/__init__.py b/vllm/entrypoints/openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py new file mode 100644 index 0000000..6a8bdd0 --- /dev/null +++ b/vllm/entrypoints/openai/api_server.py @@ -0,0 +1,1121 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import atexit +import gc +import importlib +import inspect +import multiprocessing +import os +import re +import signal +import socket +import tempfile +import uuid +from argparse import Namespace +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from functools import partial +from http import HTTPStatus +from typing import Annotated, Optional, Union + +import uvloop +from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.concurrency import iterate_in_threadpool +from starlette.datastructures import State +from starlette.routing import Mount +from typing_extensions import assert_never + +import vllm.envs as envs +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (load_chat_template, + resolve_hf_chat_template, + resolve_mistral_chat_template) +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, 
+ ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, + LoadLoRAAdapterRequest, + PoolingChatRequest, + PoolingCompletionRequest, + PoolingRequest, PoolingResponse, + RerankRequest, RerankResponse, + ScoreRequest, ScoreResponse, + TokenizeRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + UnloadLoRAAdapterRequest) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_score import ServingScores +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, + with_cancellation) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) +from vllm.transformers_utils.tokenizer import MistralTokenizer +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, + is_valid_ipv6_address, set_ulimit) +from vllm.version import __version__ as VLLM_VERSION + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.openai.api_server') + +_running_tasks: set[asyncio.Task] = set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. 
+ gc.collect() + gc.freeze() + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state + + +@asynccontextmanager +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[EngineClient]: + + # Context manager to handle engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[EngineClient]: + """ + Create EngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + + # Create the EngineConfig (determines if we can use V1). + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + + # V1 AsyncLLM. + if envs.VLLM_USE_V1: + if disable_frontend_multiprocessing: + logger.warning( + "V1 is enabled, but got --disable-frontend-multiprocessing. " + "To disable frontend multiprocessing, set VLLM_USE_V1=0.") + + from vllm.v1.engine.async_llm import AsyncLLM + async_llm: Optional[AsyncLLM] = None + try: + async_llm = AsyncLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_requests=engine_args.disable_log_requests, + disable_log_stats=engine_args.disable_log_stats) + yield async_llm + finally: + if async_llm: + async_llm.shutdown() + + # V0 AsyncLLM. + elif (MQLLMEngineClient.is_unsupported_config(vllm_config) + or disable_frontend_multiprocessing): + + engine_client: Optional[EngineClient] = None + try: + engine_client = AsyncLLMEngine.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_requests=engine_args.disable_log_requests, + disable_log_stats=engine_args.disable_log_stats) + yield engine_client + finally: + if engine_client and hasattr(engine_client, "shutdown"): + engine_client.shutdown() + + # V0MQLLMEngine. + else: + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + # Select random path for IPC. + ipc_path = get_open_zmq_ipc_path() + logger.debug("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) + + # Start RPCServer in separate process (holds the LLMEngine). + # the current process might have CUDA context, + # so we need to spawn a new process + context = multiprocessing.get_context("spawn") + + # Ensure we can serialize transformer config before spawning + maybe_register_config_serialize_by_value() + + # The Process can raise an exception during startup, which may + # not actually result in an exitcode being reported. 
As a result + # we use a shared variable to communicate the information. + engine_alive = multiprocessing.Value('b', True, lock=False) + engine_process = context.Process( + target=run_mp_engine, + args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path, + engine_args.disable_log_stats, + engine_args.disable_log_requests, engine_alive)) + engine_process.start() + engine_pid = engine_process.pid + assert engine_pid is not None, "Engine process failed to start." + logger.info("Started engine process with PID %d", engine_pid) + + def _cleanup_ipc_path(): + socket_path = ipc_path.replace("ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + + # Ensure we clean up the local IPC socket file on exit. + atexit.register(_cleanup_ipc_path) + + # Build RPCClient, which conforms to EngineClient Protocol. + build_client = partial(MQLLMEngineClient, ipc_path, vllm_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) + try: + while True: + try: + await mq_engine_client.setup() + break + except TimeoutError: + if (not engine_process.is_alive() + or not engine_alive.value): + raise RuntimeError( + "Engine process failed to start. See stack " + "trace for the root cause.") from None + + yield mq_engine_client # type: ignore[misc] + finally: + # Ensure rpc server process was terminated + engine_process.terminate() + + # Close all open connections to the backend + mq_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() + + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) + + +async def validate_json_request(raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + media_type = content_type.split(";", maxsplit=1)[0] + if media_type != "application/json": + raise HTTPException( + status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + detail="Unsupported Media Type: Only 'application/json' is allowed" + ) + + +router = APIRouter() + + +def mount_metrics(app: FastAPI): + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. 
+ # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + from prometheus_fastapi_instrumentator import Instrumentator + + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) + if prometheus_multiproc_dir_path is not None: + logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + Instrumentator( + excluded_handlers=[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + ], + registry=registry, + ).add().instrument(app).expose(app) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + else: + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") + app.routes.append(metrics_route) + + +def base(request: Request) -> OpenAIServing: + # Reuse the existing instance + return tokenization(request) + + +def models(request: Request) -> OpenAIServingModels: + return request.app.state.openai_serving_models + + +def chat(request: Request) -> Optional[OpenAIServingChat]: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> Optional[OpenAIServingCompletion]: + return request.app.state.openai_serving_completion + + +def pooling(request: Request) -> Optional[OpenAIServingPooling]: + return request.app.state.openai_serving_pooling + + +def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: + return request.app.state.openai_serving_embedding + + +def score(request: Request) -> Optional[ServingScores]: + return request.app.state.openai_serving_scores + + +def rerank(request: Request) -> Optional[ServingScores]: + return request.app.state.openai_serving_scores + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def transcription(request: Request) -> OpenAIServingTranscription: + return request.app.state.openai_serving_transcription + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health") +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.get("/load") +async def get_server_load_metrics(request: Request): + # This endpoint returns the current server load metrics. + # It tracks requests utilizing the GPU from the following routes: + # - /v1/chat/completions + # - /v1/completions + # - /v1/audio/transcriptions + # - /v1/embeddings + # - /pooling + # - /score + # - /v1/score + # - /rerank + # - /v1/rerank + # - /v2/rerank + return JSONResponse( + content={'server_load': request.app.state.server_load_metrics}) + + +@router.api_route("/ping", methods=["GET", "POST"]) +async def ping(raw_request: Request) -> Response: + """Ping check. 
Endpoint required for SageMaker""" + return await health(raw_request) + + +@router.post("/tokenize", dependencies=[Depends(validate_json_request)]) +@with_cancellation +async def tokenize(request: TokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + generator = await handler.create_tokenize(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/detokenize", dependencies=[Depends(validate_json_request)]) +@with_cancellation +async def detokenize(request: DetokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + generator = await handler.create_detokenize(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request): + handler = models(raw_request) + + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) + + +@router.get("/version") +async def show_version(): + ver = {"version": VLLM_VERSION} + return JSONResponse(content=ver) + + +@router.post("/v1/chat/completions", + dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + handler = chat(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Chat Completions API") + + generator = await handler.create_chat_completion(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ChatCompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/completions", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_completion(request: CompletionRequest, raw_request: Request): + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + generator = await handler.create_completion(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + handler = embedding(raw_request) + if handler is None: + fallback_handler = pooling(raw_request) + if fallback_handler is None: + return base(raw_request).create_error_response( + message="The model does not support Embeddings API") + + logger.warning( + "Embeddings API will become exclusive to embedding models " + "in a future release. 
To return the hidden states directly, " + "use the Pooling API (`/pooling`) instead.") + + res = await fallback_handler.create_pooling(request, raw_request) + + generator: Union[ErrorResponse, EmbeddingResponse] + if isinstance(res, PoolingResponse): + generator = EmbeddingResponse( + id=res.id, + object=res.object, + created=res.created, + model=res.model, + data=[ + EmbeddingResponseData( + index=d.index, + embedding=d.data, # type: ignore + ) for d in res.data + ], + usage=res.usage, + ) + else: + generator = res + else: + generator = await handler.create_embedding(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, EmbeddingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/pooling", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_pooling(request: PoolingRequest, raw_request: Request): + handler = pooling(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Pooling API") + + generator = await handler.create_pooling(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, PoolingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/score", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_score(request: ScoreRequest, raw_request: Request): + handler = score(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Score API") + + generator = await handler.create_score(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ScoreResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/score", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_score_v1(request: ScoreRequest, raw_request: Request): + logger.warning( + "To indicate that Score API is not part of standard OpenAI API, we " + "have moved it to `/score`. 
Please update your client accordingly.") + + return await create_score(request, raw_request) + + +@router.post("/v1/audio/transcriptions") +@with_cancellation +@load_aware_call +async def create_transcriptions(request: Annotated[TranscriptionRequest, + Form()], + raw_request: Request): + handler = transcription(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Transcriptions API") + + audio_data = await request.file.read() + generator = await handler.create_transcription(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranscriptionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/rerank", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def do_rerank(request: RerankRequest, raw_request: Request): + handler = rerank(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Rerank (Score) API") + generator = await handler.do_rerank(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, RerankResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)]) +@with_cancellation +async def do_rerank_v1(request: RerankRequest, raw_request: Request): + logger.warning_once( + "To indicate that the rerank API is not part of the standard OpenAI" + " API, we have located it at `/rerank`. Please update your client " + "accordingly. (Note: Conforms to JinaAI rerank API)") + + return await do_rerank(request, raw_request) + + +@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)]) +@with_cancellation +async def do_rerank_v2(request: RerankRequest, raw_request: Request): + return await do_rerank(request, raw_request) + + +TASK_HANDLERS: dict[str, dict[str, tuple]] = { + "generate": { + "messages": (ChatCompletionRequest, create_chat_completion), + "default": (CompletionRequest, create_completion), + }, + "embed": { + "messages": (EmbeddingChatRequest, create_embedding), + "default": (EmbeddingCompletionRequest, create_embedding), + }, + "score": { + "default": (RerankRequest, do_rerank) + }, + "rerank": { + "default": (RerankRequest, do_rerank) + }, + "reward": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, + "classify": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, +} + +if envs.VLLM_SERVER_DEV_MODE: + + @router.post("/reset_prefix_cache") + async def reset_prefix_cache(raw_request: Request): + """ + Reset the prefix cache. Note that we currently do not check if the + prefix cache is successfully reset in the API server. 
+ """ + device = None + device_str = raw_request.query_params.get("device") + if device_str is not None: + device = Device[device_str.upper()] + logger.info("Resetting prefix cache with specific %s...", str(device)) + await engine_client(raw_request).reset_prefix_cache(device) + return Response(status_code=200) + + @router.post("/sleep") + async def sleep(raw_request: Request): + # get POST params + level = raw_request.query_params.get("level", "1") + await engine_client(raw_request).sleep(int(level)) + # FIXME: in v0 with frontend multiprocessing, the sleep command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + @router.post("/wake_up") + async def wake_up(raw_request: Request): + tags = raw_request.query_params.getlist("tags") + if tags == []: + # set to None to wake up all tags if no tags are provided + tags = None + logger.info("wake up the engine with tags: %s", tags) + await engine_client(raw_request).wake_up(tags) + # FIXME: in v0 with frontend multiprocessing, the wake-up command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + @router.get("/is_sleeping") + async def is_sleeping(raw_request: Request): + logger.info("check whether the engine is sleeping") + is_sleeping = await engine_client(raw_request).is_sleeping() + return JSONResponse(content={"is_sleeping": is_sleeping}) + + +@router.post("/invocations", dependencies=[Depends(validate_json_request)]) +async def invocations(raw_request: Request): + """ + For SageMaker, routes requests to other handlers based on model `task`. + """ + body = await raw_request.json() + task = raw_request.app.state.task + + if task not in TASK_HANDLERS: + raise HTTPException( + status_code=400, + detail=f"Unsupported task: '{task}' for '/invocations'. " + f"Expected one of {set(TASK_HANDLERS.keys())}") + + handler_config = TASK_HANDLERS[task] + if "messages" in body: + request_model, handler = handler_config["messages"] + else: + request_model, handler = handler_config["default"] + + # this is required since we lose the FastAPI automatic casting + request = request_model.model_validate(body) + return await handler(request, raw_request) + + +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. 
" + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter", + dependencies=[Depends(validate_json_request)]) + async def load_lora_adapter(request: LoadLoRAAdapterRequest, + raw_request: Request): + handler = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter", + dependencies=[Depends(validate_json_request)]) + async def unload_lora_adapter(request: UnloadLoRAAdapterRequest, + raw_request: Request): + handler = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + +def build_app(args: Namespace) -> FastAPI: + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + mount_metrics(app) + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + err = ErrorResponse(message=str(exc), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + + # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY + if token := args.api_key or envs.VLLM_API_KEY: + + @app.middleware("http") + async def authentication(request: Request, call_next): + if request.method == "OPTIONS": + return await call_next(request) + url_path = request.url.path + if app.root_path and url_path.startswith(app.root_path): + url_path = url_path[len(app.root_path):] + if not url_path.startswith("/v1"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + token: + return JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return await call_next(request) + + if args.enable_request_id_headers: + logger.warning( + "CAUTION: Enabling X-Request-Id headers in the API Server. " + "This can harm performance at high QPS.") + + @app.middleware("http") + async def add_request_id(request: Request, call_next): + request_id = request.headers.get( + "X-Request-Id") or uuid.uuid4().hex + response = await call_next(request) + response.headers["X-Request-Id"] = request_id + return response + + if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: + logger.warning("CAUTION: Enabling log response in the API Server. 
" + "This can include sensitive information and should be " + "avoided in production.") + + @app.middleware("http") + async def log_response(request: Request, call_next): + response = await call_next(request) + response_body = [ + section async for section in response.body_iterator + ] + response.body_iterator = iterate_in_threadpool(iter(response_body)) + logger.info("response_body={%s}", response_body[0].decode()) + return response + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) # type: ignore[arg-type] + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError(f"Invalid middleware {middleware}. " + f"Must be a function or a class.") + + return app + + +async def init_app_state( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + + resolved_chat_template = load_chat_template(args.chat_template) + if resolved_chat_template is not None: + # Get the tokenizer to check official template + tokenizer = await engine_client.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + # The warning is logged in resolve_mistral_chat_template. + resolved_chat_template = resolve_mistral_chat_template( + chat_template=resolved_chat_template) + else: + hf_chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=None, + tools=None, + trust_remote_code=model_config.trust_remote_code) + + if hf_chat_template != resolved_chat_template: + logger.warning( + "Using supplied chat template: %s\n" + "It is different from official chat template '%s'. 
" + "This discrepancy may lead to performance degradation.", + resolved_chat_template, args.model) + + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + ) + await state.openai_serving_models.init_static_loras() + state.openai_serving_chat = OpenAIServingChat( + engine_client, + model_config, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + enable_reasoning=args.enable_reasoning, + reasoning_parser=args.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) if model_config.runner_type == "generate" else None + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) if model_config.runner_type == "generate" else None + state.openai_serving_pooling = OpenAIServingPooling( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.runner_type == "pooling" else None + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.task == "embed" else None + state.openai_serving_scores = ServingScores( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger) if model_config.task in ( + "score", "embed", "pooling") else None + state.jinaai_serving_reranking = ServingScores( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger + ) if model_config.task == "score" else None + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) + state.openai_serving_transcription = OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None + state.task = model_config.task + + state.enable_server_load_tracking = args.enable_server_load_tracking + state.server_load_metrics = 0 + + +def create_server_socket(addr: tuple[str, int]) -> socket.socket: + family = socket.AF_INET + if is_valid_ipv6_address(addr[0]): + family = socket.AF_INET6 + + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind(addr) + + return sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + 
ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valid_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valid_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parses)} }})") + + valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() + if args.enable_reasoning \ + and args.reasoning_parser not in valid_reasoning_parses: + raise KeyError( + f"invalid reasoning parser: {args.reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + await init_app_state(engine_client, model_config, app.state, args) + + def _listen_addr(a: str) -> str: + if is_valid_ipv6_address(a): + return '[' + a + ']' + return a or "0.0.0.0" + + is_ssl = args.ssl_keyfile and args.ssl_certfile + logger.info("Starting vLLM API server on http%s://%s:%d", + "s" if is_ssl else "", _listen_addr(sock_addr[0]), + sock_addr[1]) + + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/entrypoints/cli/main.py for CLI + # entrypoints. + cli_env_setup() + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py new file mode 100644 index 0000000..218a8fb --- /dev/null +++ b/vllm/entrypoints/openai/cli_args.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains the command line arguments for the vLLM's +OpenAI-compatible server. It is kept in a separate file for documentation +purposes. 
+""" + +import argparse +import json +import ssl +from collections.abc import Sequence +from typing import Optional, Union, get_args + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + validate_chat_template) +from vllm.entrypoints.openai.serving_models import (LoRAModulePath, + PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.utils import FlexibleArgumentParser + + +class LoRAParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: list[LoRAModulePath] = [] + for item in values: + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) + setattr(namespace, self.dest, lora_list) + + +class PromptAdapterParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + adapter_list: list[PromptAdapterPath] = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument("--host", + type=nullable_str, + default=None, + help="Host name.") + parser.add_argument("--port", type=int, default=8000, help="Port number.") + parser.add_argument( + "--uvicorn-log-level", + type=str, + default="info", + choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], + help="Log level for uvicorn.") + parser.add_argument("--disable-uvicorn-access-log", + action="store_true", + help="Disable uvicorn access log.") + parser.add_argument("--allow-credentials", + action="store_true", + help="Allow credentials.") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="Allowed origins.") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="Allowed methods.") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="Allowed headers.") + parser.add_argument("--api-key", + type=nullable_str, + default=None, + help="If provided, the server will require this key " + "to be presented in the header.") + parser.add_argument( + "--lora-modules", + type=nullable_str, + default=None, + nargs='+', + action=LoRAParserAction, + help="LoRA module configurations in either 'name=path' format" + "or JSON format. 
" + "Example (old format): ``'name=path'`` " + "Example (new format): " + "``{\"name\": \"name\", \"path\": \"lora_path\", " + "\"base_model_name\": \"id\"}``") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") + parser.add_argument("--chat-template", + type=nullable_str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model.") + parser.add_argument( + '--chat-template-content-format', + type=str, + default="auto", + choices=get_args(ChatTemplateContentFormatOption), + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render the content as a string. ' + 'Example: ``"Hello World"``\n' + '* "openai" will render the content as a list of dictionaries, ' + 'similar to OpenAI schema. ' + 'Example: ``[{"type": "text", "text": "Hello world!"}]``') + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "``request.add_generation_prompt=true``.") + parser.add_argument("--ssl-keyfile", + type=nullable_str, + default=None, + help="The file path to the SSL key file.") + parser.add_argument("--ssl-certfile", + type=nullable_str, + default=None, + help="The file path to the SSL cert file.") + parser.add_argument("--ssl-ca-certs", + type=nullable_str, + default=None, + help="The CA certificates file.") + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)." + ) + parser.add_argument( + "--root-path", + type=nullable_str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy." + ) + parser.add_argument( + "--middleware", + type=nullable_str, + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. " + "The value should be an import path. " + "If a function is provided, vLLM will add it to the server " + "using ``@app.middleware('http')``. " + "If a class is provided, vLLM will add it to the server " + "using ``app.add_middleware()``. ") + parser.add_argument( + "--return-tokens-as-token-ids", + action="store_true", + help="When ``--max-logprobs`` is specified, represents single tokens " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") + parser.add_argument( + "--enable-request-id-headers", + action="store_true", + help="If specified, API server will add X-Request-Id header to " + "responses. Caution: this hurts performance at high QPS.") + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help="Enable auto tool choice for supported models. 
Use " + "``--tool-call-parser`` to specify which parser to use.") + + valid_tool_parsers = ToolParserManager.tool_parsers.keys() + parser.add_argument( + "--tool-call-parser", + type=str, + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for ``--enable-auto-tool-choice``.") + + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Special the tool parser plugin write to parse the model-generated tool" + " into OpenAI API format, the name register in this plugin can be used " + "in ``--tool-call-parser``.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + ' The default of None means unlimited.') + + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint." + ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") + parser.add_argument( + "--enable-server-load-tracking", + action='store_true', + default=False, + help= + "If set to True, enable tracking server_load_metrics in the app state." + ) + + return parser + + +def validate_parsed_serve_args(args: argparse.Namespace): + """Quick checks for model serve args that raise prior to loading.""" + if hasattr(args, "subparser") and args.subparser != "serve": + return + + # Ensure that the chat template is valid; raises if it likely isn't + validate_chat_template(args.chat_template) + + # Enable auto tool needs a tool call parser to be valid + if args.enable_auto_tool_choice and not args.tool_call_parser: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + + # Enable reasoning needs a reasoning parser to be valid + if args.enable_reasoning and not args.reasoning_parser: + raise TypeError("Error: --enable-reasoning requires " + "--reasoning-parser") + + +def create_parser_for_docs() -> FlexibleArgumentParser: + parser_for_docs = FlexibleArgumentParser( + prog="-m vllm.entrypoints.openai.api_server") + return make_arg_parser(parser_for_docs) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py new file mode 100644 index 0000000..04d5091 --- /dev/null +++ b/vllm/entrypoints/openai/logits_processors.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable +from functools import lru_cache, partial +from typing import Optional, Union + +import torch + +from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class AllowedTokenIdsLogitsProcessor: + """Logits processor for constraining generated tokens to a + specific set of token ids.""" + + def __init__(self, allowed_ids: Iterable[int]): + self.allowed_ids: Optional[list[int]] = list(allowed_ids) + self.mask: Optional[torch.Tensor] = None + + def __call__(self, token_ids: list[int], + logits: torch.Tensor) -> torch.Tensor: + if self.mask is None: + self.mask = torch.ones((logits.shape[-1], ), + dtype=torch.bool, + device=logits.device) + self.mask[self.allowed_ids] = False + 
self.allowed_ids = None + logits.masked_fill_(self.mask, float("-inf")) + return logits + + +@lru_cache(maxsize=32) +def _get_allowed_token_ids_logits_processor( + allowed_token_ids: frozenset[int], + vocab_size: int, +) -> LogitsProcessor: + if not allowed_token_ids: + raise ValueError("Empty allowed_token_ids provided") + if not all(0 <= tid < vocab_size for tid in allowed_token_ids): + raise ValueError("allowed_token_ids contains " + "out-of-vocab token id") + return AllowedTokenIdsLogitsProcessor(allowed_token_ids) + + +def logit_bias_logits_processor( + logit_bias: dict[int, float], + token_ids: list[int], + logits: torch.Tensor, +) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + +def get_logits_processors( + logit_bias: Optional[Union[dict[int, float], dict[str, float]]], + allowed_token_ids: Optional[list[int]], + tokenizer: AnyTokenizer, +) -> list[LogitsProcessor]: + logits_processors: list[LogitsProcessor] = [] + if logit_bias: + try: + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec + clamped_logit_bias: dict[int, float] = { + int(token_id): min(100.0, max(-100.0, bias)) + for token_id, bias in logit_bias.items() + } + except ValueError as exc: + raise ValueError( + "Found token_id in logit_bias that is not " + "an integer or string representing an integer") from exc + + # Check if token_id is within the vocab size + for token_id, bias in clamped_logit_bias.items(): + if token_id < 0 or token_id >= len(tokenizer): + raise ValueError(f"token_id {token_id} in logit_bias contains " + "out-of-vocab token id") + + logits_processors.append( + partial(logit_bias_logits_processor, clamped_logit_bias)) + + if allowed_token_ids is not None: + logits_processors.append( + _get_allowed_token_ids_logits_processor( + frozenset(allowed_token_ids), len(tokenizer))) + + return logits_processors diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py new file mode 100644 index 0000000..7cbd9d7 --- /dev/null +++ b/vllm/entrypoints/openai/protocol.py @@ -0,0 +1,1718 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import re +import time +from argparse import Namespace +from typing import Annotated, Any, ClassVar, Literal, Optional, Union + +import torch +from fastapi import UploadFile +from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, + ValidationInfo, field_validator, model_validator) +from typing_extensions import TypeAlias + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.logger import init_logger +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.sequence import Logprob +from vllm.utils import random_uuid, resolve_obj_by_qualname + +logger = init_logger(__name__) + +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) +_LONG_INFO: Union["torch.iinfo", Namespace] + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min 
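+# Sanity check: the hard-coded mock bounds must match the real
+# torch.iinfo(torch.long) values, which are used below (e.g. to bound `seed`).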
+assert _LONG_INFO.max == _MOCK_LONG_INFO.max + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does allow extra fields + model_config = ConfigDict(extra="allow") + + # Cache class field names + field_names: ClassVar[Optional[set[str]]] = None + + @model_validator(mode="wrap") + @classmethod + def __log_extra_fields__(cls, data, handler): + result = handler(data) + if not isinstance(data, dict): + return result + field_names = cls.field_names + if field_names is None: + # Get all class field names and their potential aliases + field_names = set() + for field_name, field in cls.model_fields.items(): + field_names.add(field_name) + if alias := getattr(field, "alias", None): + field_names.add(alias) + cls.field_names = field_names + + # Compare against both field names and aliases + if any(k not in field_names for k in data): + logger.warning( + "The following fields were present in the request " + "but ignored: %s", + data.keys() - field_names, + ) + return result + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + permission: list[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: list[ModelCard] = Field(default_factory=list) + + +class PromptTokenUsageInfo(OpenAIBaseModel): + cached_tokens: Optional[int] = None + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + prompt_tokens_details: Optional[PromptTokenUsageInfo] = None + + +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + +class JsonSchemaResponseFormat(OpenAIBaseModel): + name: str + description: Optional[str] = None + # schema is the field in openai but that causes conflicts with pydantic so + # instead use json_schema with an alias + json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') + strict: Optional[bool] = None + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_schema", "json_object" or "text" + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +class StreamOptions(OpenAIBaseModel): + include_usage: Optional[bool] = True + continuous_usage_stats: Optional[bool] = False + + +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: 
ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +class LogitsProcessorConstructor(BaseModel): + qualname: str + args: Optional[list[Any]] = None + kwargs: Optional[dict[str, Any]] = None + + +LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] + + +def get_logits_processors(processors: Optional[LogitsProcessors], + pattern: Optional[str]) -> Optional[list[Any]]: + if processors and pattern: + logits_processors = [] + for processor in processors: + qualname = processor if isinstance(processor, + str) else processor.qualname + if not re.match(pattern, qualname): + raise ValueError( + f"Logits processor '{qualname}' is not allowed by this " + "server. See --logits-processor-pattern engine argument " + "for more information.") + try: + logits_processor = resolve_obj_by_qualname(qualname) + except Exception as e: + raise ValueError( + f"Logits processor '{qualname}' could not be resolved: {e}" + ) from e + if isinstance(processor, LogitsProcessorConstructor): + logits_processor = logits_processor(*processor.args or [], + **processor.kwargs or {}) + logits_processors.append(logits_processor) + return logits_processors + elif processors: + raise ValueError( + "The `logits_processors` argument is not supported by this " + "server. See --logits-processor-pattern engine argugment " + "for more information.") + return None + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: list[ChatCompletionMessageParam] + model: Optional[str] = None + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens: Optional[int] = Field( + default=None, + deprecated= + 'max_tokens is deprecated in favor of the max_completion_tokens field') + max_completion_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, list[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + tools: Optional[list[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[ + Literal["none"], + Literal["auto"], + Literal["required"], + ChatCompletionNamedToolChoiceParam, + ]] = "none" + + # NOTE this will be ignored by vLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False + user: Optional[str] = None + + # doc: begin-chat-completion-sampling-params + best_of: Optional[int] = None + use_beam_search: bool = False + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None + length_penalty: float = 1.0 + stop_token_ids: Optional[list[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + prompt_logprobs: Optional[int] = None + # doc: end-chat-completion-sampling-params + + # doc: begin-chat-completion-extra-params + echo: bool = Field( + default=False, + description=( + "If true, 
the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + documents: Optional[list[dict[str, str]]] = Field( + default=None, + description= + ("A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + "\"title\" and \"text\" keys."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[list[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'"), + ) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. 
If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "throughout the inference process and returned in the response."),
+    )
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))
+    return_tokens_as_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified with 'logprobs', tokens are represented "
+            " as strings of the form 'token_id:{token_id}' so that tokens "
+            "that are not JSON-encodable can be identified."))
+
+    # doc: end-chat-completion-extra-params
+
+    # Default sampling parameters for chat completion requests
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+    }
+
+    def to_beam_search_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None
+    ) -> BeamSearchParams:
+        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+        max_tokens = self.max_completion_tokens or self.max_tokens
+
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        n = self.n if self.n is not None else 1
+
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+            length_penalty=self.length_penalty,
+            include_stop_str_in_output=self.include_stop_str_in_output,
+        )
+
+    def to_sampling_params(
+            self,
+            default_max_tokens: int,
+            logits_processor_pattern: Optional[str],
+            default_sampling_params: Optional[dict] = None,
+    ) -> SamplingParams:
+        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+        max_tokens = self.max_completion_tokens or self.max_tokens
+
+        if default_sampling_params is None:
+            default_sampling_params = {}
+
+        # Use minimum of context window, user request & server limit.
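+        # Unset limits (None) are filtered out, so only values that are
+        # actually configured take part in the min().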
+ max_tokens = min( + val for val in (default_max_tokens, max_tokens, + default_sampling_params.get("max_tokens", None)) + if val is not None) + + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.top_logprobs + + guided_json_object = None + if self.response_format is not None: + if self.response_format.type == "json_object": + guided_json_object = True + elif self.response_format.type == "json_schema": + json_schema = self.response_format.json_schema + assert json_schema is not None + self.guided_json = json_schema.json_schema + if self.guided_decoding_backend is None: + self.guided_decoding_backend = "xgrammar" + + guided_decoding = GuidedDecodingParams.from_optional( + json=self._get_guided_json_from_tool() or self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern, + ) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=prompt_logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens, + min_tokens=self.min_tokens, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + logits_processors=get_logits_processors(self.logits_processors, + logits_processor_pattern), + include_stop_str_in_output=self.include_stop_str_in_output, + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias) + + def _get_guided_json_from_tool( + self) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if self.tool_choice == "none" or self.tools is None: + return None + + # user has chosen to use a named tool + if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = self.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in self.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + if self.tool_choice == "required": + # Pydantic schema generation cannot be used since the JSON schema + # has to be constructed for a specific instantiation of a tool list + # so that 
parameters of a function are correctly generated + # based on the chosen function name + def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: + return { + "properties": { + "name": { + "type": "string", + "enum": [tool.function.name] + }, + # parameters are always generated as '{}' in the final + # output if they are missing from the request + # (i.e. are None or '{}') so the schema is + # updated to produce an empty object in that case + "parameters": tool.function.parameters + if tool.function.parameters else { + "type": "object", + "properties": {} + } + }, + "required": ["name", "parameters"] + } + + json_schema = { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": [get_tool_schema(tool) for tool in self.tools] + } + } + return json_schema + + return None + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if top_logprobs > 0 and not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + # you can only use one kind of guided decoding + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count > 1 and data.get("tool_choice", "none") not in ( + "none", + "auto", + "required", + ): + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and data.get("tools"): + data["tool_choice"] = "auto" + + # if "tool_choice" is "none" -- ignore tools if present + if "tool_choice" in data and data["tool_choice"] == "none": + # ensure that no tools are present + data.pop("tools", None) + return data + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" or "required" + if data["tool_choice"] not in [ + "auto", "required" + ] and not isinstance(data["tool_choice"], dict): + raise NotImplementedError( + f'Invalid value 
for `tool_choice`: {data["tool_choice"]}! '\ + 'Only named tools, "none", "auto" or "required" '\ + 'are supported.' + ) + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"].get("function") + if not specified_function: + raise ValueError( + "Expected field `function` in `tool_choice`." + " Correct usage: `{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") + specified_function_name = specified_function.get("name") + if not specified_function_name: + raise ValueError( + "Expected field `name` in `function` in `tool_choice`." + "Correct usage: `{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") + return data + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: Optional[str] = None + prompt: Union[list[int], list[list[int]], str, list[str]] + best_of: Optional[int] = None + echo: Optional[bool] = False + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = 16 + n: int = 1 + presence_penalty: Optional[float] = 0.0 + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, list[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + user: Optional[str] = None + + # doc: begin-completion-sampling-params + use_beam_search: bool = False + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None + length_penalty: float = 1.0 + stop_token_ids: Optional[list[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + allowed_token_ids: Optional[list[int]] = None + prompt_logprobs: Optional[int] = None + # doc: end-completion-sampling-params + + # doc: begin-completion-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + response_format: Optional[ResponseFormat] = Field( + default=None, + description= + ("Similar to chat completion, this parameter specifies the format of " + "output. 
Only {'type': 'json_object'}, {'type': 'json_schema'} or " + "{'type': 'text' } is supported."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description="If specified, the output will follow the JSON schema.", + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[list[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'"), + ) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + logits_processors: Optional[LogitsProcessors] = Field( + default=None, + description=( + "A list of either qualified names of logits processors, or " + "constructor objects, to apply when sampling. A constructor is " + "a JSON object with a required 'qualname' field specifying the " + "qualified name of the processor class/factory, and optional " + "'args' and 'kwargs' fields containing positional and keyword " + "arguments. For example: {'qualname': " + "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " + "{'param': 'value'}}.")) + + return_tokens_as_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified with 'logprobs', tokens are represented " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.")) + + # doc: end-completion-extra-params + + # Default sampling parameters for completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + } + + def to_beam_search_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None + ) -> BeamSearchParams: + max_tokens = self.max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + n = self.n if self.n is not None else 1 + + # Use minimum of context window, user request & server limit. 
+ max_tokens = min( + val for val in (default_max_tokens, max_tokens, + default_sampling_params.get("max_tokens", None)) + if val is not None) + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 1.0) + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output, + ) + + def to_sampling_params( + self, + default_max_tokens: int, + logits_processor_pattern: Optional[str], + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: + max_tokens = self.max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + + # Use minimum of context window, user request & server limit. + max_tokens = min( + val for val in (default_max_tokens, max_tokens, + default_sampling_params.get("max_tokens", None)) + if val is not None) + + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + + echo_without_generation = self.echo and self.max_tokens == 0 + + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern, + ) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, + prompt_logprobs=prompt_logprobs, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + logits_processors=get_logits_processors(self.logits_processors, + logits_processor_pattern), + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + 
"guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("`logprobs` must be a positive value.") + + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +class EmbeddingCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: Optional[str] = None + input: Union[list[int], list[list[int]], str, list[str]] + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-embedding-pooling-params + + # doc: begin-embedding-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # doc: end-embedding-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class EmbeddingChatRequest(OpenAIBaseModel): + model: Optional[str] = None + messages: list[ChatCompletionMessageParam] + + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-chat-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-chat-embedding-pooling-params + + # doc: begin-chat-embedding-extra-params + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. 
" + "Will be accessible by the chat template."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + # doc: end-chat-embedding-extra-params + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] + +PoolingCompletionRequest = EmbeddingCompletionRequest +PoolingChatRequest = EmbeddingChatRequest +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest] + + +class ScoreRequest(OpenAIBaseModel): + model: Optional[str] = None + text_1: Union[list[str], str] + text_2: Union[list[str], str] + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-score-pooling-params + additional_data: Optional[Any] = None + # doc: end-score-pooling-params + + # doc: begin-score-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # doc: end-score-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class RerankRequest(OpenAIBaseModel): + model: Optional[str] = None + query: str + documents: list[str] + top_n: int = Field(default_factory=lambda: 0) + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-rerank-pooling-params + additional_data: Optional[Any] = None + # doc: end-rerank-pooling-params + + # doc: begin-rerank-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). 
Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # doc: end-rerank-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class RerankDocument(BaseModel): + text: str + + +class RerankResult(BaseModel): + index: int + document: RerankDocument + relevance_score: float + + +class RerankUsage(BaseModel): + total_tokens: int + + +class RerankResponse(OpenAIBaseModel): + id: str + model: str + usage: RerankUsage + results: list[RerankResult] + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: list[int] = Field(default_factory=list) + token_logprobs: list[Optional[float]] = Field(default_factory=list) + tokens: list[str] = Field(default_factory=list) + top_logprobs: list[Optional[dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class EmbeddingResponseData(OpenAIBaseModel): + index: int + object: str = "embedding" + embedding: Union[list[float], str] + + +class EmbeddingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[EmbeddingResponseData] + usage: UsageInfo + + +class PoolingResponseData(OpenAIBaseModel): + index: int + object: str = "pooling" + data: Union[list[list[float]], list[float], str] + + +class PoolingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"pool-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[PoolingResponseData] + usage: UsageInfo + + +class ScoreResponseData(OpenAIBaseModel): + index: int + object: str = "score" + score: float + + +class ScoreResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: 
list[ScoreResponseData] + usage: UsageInfo + + +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + function: FunctionCall + + +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + function: Optional[DeltaFunctionCall] = None + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: list[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + +class ChatMessage(OpenAIBaseModel): + role: str + reasoning_content: Optional[str] = None + content: Optional[str] = None + tool_calls: list[ToolCall] = Field(default_factory=list) + + +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[list[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + # Workaround: redefine fields name cache so that it's not + # shared with the super class. + field_names: ClassVar[Optional[set[str]]] = None + top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[list[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[ChatCompletionLogProbs] = None + # per OpenAI spec this is the default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionResponseChoice] + usage: UsageInfo + prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: list[DeltaToolCall] = Field(default_factory=list) + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class TranscriptionResponseStreamChoice(OpenAIBaseModel): + delta: DeltaMessage + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class TranscriptionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: 
f"trsc-{random_uuid()}") + object: Literal["transcription.chunk"] = "transcription.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[TranscriptionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. + body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] + + @field_validator('body', mode='plain') + @classmethod + def check_type_for_url(cls, value: Any, info: ValidationInfo): + # Use url to disambiguate models + url = info.data['url'] + if url == "/v1/chat/completions": + return ChatCompletionRequest.model_validate(value) + if url == "/v1/embeddings": + return TypeAdapter(EmbeddingRequest).validate_python(value) + if url == "/v1/score": + return ScoreRequest.model_validate(value) + return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest, + ScoreRequest]).validate_python(value) + + +class BatchResponseData(OpenAIBaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, + ScoreResponse]] = None + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] + + +class TokenizeCompletionRequest(OpenAIBaseModel): + model: Optional[str] = None + prompt: str + + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: Optional[str] = None + messages: list[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. 
" + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] + + +class TokenizeResponse(OpenAIBaseModel): + count: int + max_model_len: int + tokens: list[int] + + +class DetokenizeRequest(OpenAIBaseModel): + model: Optional[str] = None + tokens: list[int] + + +class DetokenizeResponse(OpenAIBaseModel): + prompt: str + + +class LoadLoRAAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoRAAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) + + +## Protocols for Audio +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", + "vtt"] + + +class TranscriptionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/audio/createTranscription + + file: UploadFile + """ + The audio file object (not file name) to transcribe, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: Optional[str] = None + """ID of the model to use. + """ + + language: Optional[str] = None + """The language of the input audio. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy and latency. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + ## TODO (varun) : Support if set to 0, certain thresholds are met !! + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + + timestamp_granularities: list[Literal["word", "segment"]] = Field( + alias="timestamp_granularities[]", default=[]) + """The timestamp granularities to populate for this transcription. + + `response_format` must be set `verbose_json` to use timestamp granularities. 
+ Either or both of these options are supported: `word`, or `segment`. Note: + There is no additional latency for segment timestamps, but generating word + timestamps incurs additional latency. + """ + + stream: Optional[bool] = False + """Custom field not present in the original OpenAI definition. When set, + it will enable output to be streamed in a similar fashion as the Chat + Completion endpoint. + """ + # Flattened stream option to simplify form data. + stream_include_usage: Optional[bool] = False + stream_continuous_usage_stats: Optional[bool] = False + + # Default sampling parameters for transcription requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens, + output_kind=RequestOutputKind.DELTA + if self.stream \ + else RequestOutputKind.FINAL_ONLY) + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] + stream = data.get("stream", False) + if any(bool(data.get(so, False)) for so in stream_opts) and not stream: + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +# Transcription response objects +class TranscriptionResponse(OpenAIBaseModel): + text: str + """The transcribed text.""" + + +class TranscriptionWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranscriptionSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. 
+ """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: list[int] + """Array of token IDs for the text content.""" + + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[list[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[list[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py new file mode 100644 index 0000000..0d06ba3 --- /dev/null +++ b/vllm/entrypoints/openai/run_batch.py @@ -0,0 +1,439 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import tempfile +from collections.abc import Awaitable +from http import HTTPStatus +from io import StringIO +from typing import Callable, Optional + +import aiohttp +import torch +from prometheus_client import start_http_server +from tqdm import tqdm + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger, logger +# yapf: disable +from vllm.entrypoints.openai.protocol import (BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + EmbeddingResponse, ErrorResponse, + ScoreResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.entrypoints.openai.serving_score import ServingScores +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.version import __version__ as VLLM_VERSION + + +def parse_args(): + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible batch runner.") + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help= + "The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.") + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.") + parser.add_argument( + "--output-tmp-dir", + type=str, + default=None, + help="The directory to store the output file before uploading it " + "to the output URL.", + ) + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=True`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' 
+ '\n\nDefault: Unlimited') + + parser.add_argument("--enable-metrics", + action="store_true", + help="Enable Prometheus metrics") + parser.add_argument( + "--url", + type=str, + default="0.0.0.0", + help="URL to the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port number for the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") + + return parser.parse_args() + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._pbar: Optional[tqdm] = None + + def submitted(self): + self._total += 1 + + def completed(self): + if self._pbar: + self._pbar.update() + + def pbar(self) -> tqdm: + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + self._pbar = tqdm(total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT) + return self._pbar + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, encoding="utf-8") as f: + return f.read() + + +async def write_local_file(output_path: str, + batch_outputs: list[BatchRequestOutput]) -> None: + """ + Write the responses to a local file. + output_path: The path to write the responses to. + batch_outputs: The list of batch outputs to write. + """ + # We should make this async, but as long as run_batch runs as a + # standalone program, blocking the event loop won't effect performance. + with open(output_path, "w", encoding="utf-8") as f: + for o in batch_outputs: + print(o.model_dump_json(), file=f) + + +async def upload_data(output_url: str, data_or_file: str, + from_file: bool) -> None: + """ + Upload a local file to a URL. + output_url: The URL to upload the file to. + data_or_file: Either the data to upload or the path to the file to upload. + from_file: If True, data_or_file is the path to the file to upload. + """ + # Timeout is a common issue when uploading large files. + # We retry max_retries times before giving up. + max_retries = 5 + # Number of seconds to wait before retrying. + delay = 5 + + for attempt in range(1, max_retries + 1): + try: + # We increase the timeout to 1000 seconds to allow + # for large files (default is 300). 
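+            # Both branches below PUT the payload to `output_url` and treat
+            # any non-200 status as a failure, so the surrounding retry loop
+            # gets another attempt before giving up.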
+ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=1000)) as session: + if from_file: + with open(data_or_file, "rb") as file: + async with session.put(output_url, + data=file) as response: + if response.status != 200: + raise Exception(f"Failed to upload file.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}") + else: + async with session.put(output_url, + data=data_or_file) as response: + if response.status != 200: + raise Exception(f"Failed to upload data.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}") + + except Exception as e: + if attempt < max_retries: + logger.error( + f"Failed to upload data (attempt {attempt}). " + f"Error message: {str(e)}.\nRetrying in {delay} seconds..." + ) + await asyncio.sleep(delay) + else: + raise Exception(f"Failed to upload data (attempt {attempt}). " + f"Error message: {str(e)}.") from e + + +async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], + output_tmp_dir: str) -> None: + """ + Write batch_outputs to a file or upload to a URL. + path_or_url: The path or URL to write batch_outputs to. + batch_outputs: The list of batch outputs to write. + output_tmp_dir: The directory to store the output file before uploading it + to the output URL. + """ + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + if output_tmp_dir is None: + logger.info("Writing outputs to memory buffer") + output_buffer = StringIO() + for o in batch_outputs: + print(o.model_dump_json(), file=output_buffer) + output_buffer.seek(0) + logger.info("Uploading outputs to %s", path_or_url) + await upload_data( + path_or_url, + output_buffer.read().strip().encode("utf-8"), + from_file=False, + ) + else: + # Write responses to a temporary file and then upload it to the URL. 
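+            # Each line written below is one BatchRequestOutput serialized
+            # as JSON; illustrative shape only:
+            #   {"id": "vllm-<uuid>", "custom_id": "...",
+            #    "response": {...}, "error": null}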
+ with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=output_tmp_dir, + prefix="tmp_batch_output_", + suffix=".jsonl", + ) as f: + logger.info("Writing outputs to temporary local file %s", + f.name) + await write_local_file(f.name, batch_outputs) + logger.info("Uploading outputs to %s", path_or_url) + await upload_data(path_or_url, f.name, from_file=True) + else: + logger.info("Writing outputs to local file %s", path_or_url) + await write_local_file(path_or_url, batch_outputs) + + +def make_error_request_output(request: BatchRequestInput, + error_msg: str) -> BatchRequestOutput: + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=HTTPStatus.BAD_REQUEST, + request_id=f"vllm-batch-{random_uuid()}", + ), + error=error_msg, + ) + return batch_output + + +async def make_async_error_request_output( + request: BatchRequestInput, error_msg: str) -> BatchRequestOutput: + return make_error_request_output(request, error_msg) + + +async def run_request(serving_engine_func: Callable, + request: BatchRequestInput, + tracker: BatchProgressTracker) -> BatchRequestOutput: + response = await serving_engine_func(request.body) + + if isinstance(response, + (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + body=response, request_id=f"vllm-batch-{random_uuid()}"), + error=None, + ) + elif isinstance(response, ErrorResponse): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=response.code, + request_id=f"vllm-batch-{random_uuid()}"), + error=response, + ) + else: + batch_output = make_error_request_output( + request, error_msg="Request must not be sent in stream mode") + + tracker.completed() + return batch_output + + +async def main(args): + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) + + model_config = await engine.get_model_config() + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + # Create the openai serving objects. 
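+    # A serving handler is only constructed for the task the loaded model
+    # supports (generate -> chat completions, embed -> embeddings,
+    # score -> scores); requests against an unsupported endpoint are turned
+    # into error outputs further below.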
+ openai_serving_models = OpenAIServingModels( + engine_client=engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + ) + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) if model_config.runner_type == "generate" else None + openai_serving_embedding = OpenAIServingEmbedding( + engine, + model_config, + openai_serving_models, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + ) if model_config.task == "embed" else None + openai_serving_scores = (ServingScores( + engine, + model_config, + openai_serving_models, + request_logger=request_logger, + ) if model_config.task == "score" else None) + + tracker = BatchProgressTracker() + logger.info("Reading batch from %s...", args.input_file) + + # Submit all requests in the file to the engine "concurrently". + response_futures: list[Awaitable[BatchRequestOutput]] = [] + for request_json in (await read_file(args.input_file)).strip().split("\n"): + # Skip empty lines. + request_json = request_json.strip() + if not request_json: + continue + + request = BatchRequestInput.model_validate_json(request_json) + + # Determine the type of request and run it. + if request.url == "/v1/chat/completions": + chat_handler_fn = (None if openai_serving_chat is None else + openai_serving_chat.create_chat_completion) + if chat_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg= + "The model does not support Chat Completions API", + )) + continue + + response_futures.append( + run_request(chat_handler_fn, request, tracker)) + tracker.submitted() + elif request.url == "/v1/embeddings": + embed_handler_fn = (None if openai_serving_embedding is None else + openai_serving_embedding.create_embedding) + if embed_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Embeddings API", + )) + continue + + response_futures.append( + run_request(embed_handler_fn, request, tracker)) + tracker.submitted() + elif request.url == "/v1/score": + score_handler_fn = (None if openai_serving_scores is None else + openai_serving_scores.create_score) + if score_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Scores API", + )) + continue + + response_futures.append( + run_request(score_handler_fn, request, tracker)) + tracker.submitted() + else: + response_futures.append( + make_async_error_request_output( + request, + error_msg= + "Only /v1/chat/completions, /v1/embeddings, and /v1/score " + "are supported in the batch endpoint.", + )) + + with tracker.pbar(): + responses = await asyncio.gather(*response_futures) + + await write_file(args.output_file, responses, args.output_tmp_dir) + + +if __name__ == "__main__": + args = parse_args() + + logger.info("vLLM batch processing API version %s", VLLM_VERSION) + logger.info("args: %s", args) + + # Start the Prometheus metrics server. LLMEngine uses the Prometheus client + # to publish metrics at the /metrics endpoint. 
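+    # Illustrative invocation (requests.jsonl / results.jsonl are
+    # placeholder file names):
+    #   python -m vllm.entrypoints.openai.run_batch \
+    #       -i requests.jsonl -o results.jsonl \
+    #       --model /model --enable-metrics --port 8000
+    # With --enable-metrics, Prometheus metrics are served on --url/--port
+    # for the duration of the batch run.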
+ if args.enable_metrics: + logger.info("Prometheus metrics enabled") + start_http_server(port=args.port, addr=args.url) + else: + logger.info("Prometheus metrics disabled") + + asyncio.run(main(args)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py new file mode 100644 index 0000000..d4d0cfa --- /dev/null +++ b/vllm/entrypoints/openai/serving_chat.py @@ -0,0 +1,1208 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import json +import re +import time +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Callable, Final, Optional, Union + +import jinja2 +import partial_json_parser +from fastapi import Request +from pydantic import TypeAdapter + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + ConversationMessage) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, ChatCompletionLogProbs, + ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition, + PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + clamp_prompt_logprobs) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( + MistralToolCall) +from vllm.logger import init_logger +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, + truncate_tool_call_ids) + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + response_role: str, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + enable_reasoning: bool = False, + reasoning_parser: Optional[str] = None, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None, + enable_prompt_tokens_details: bool = False, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + self.response_role = response_role + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + + 
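+        # When --enable-reasoning is set, the parser looked up below splits
+        # model output into `reasoning_content` and `content`; an
+        # unregistered parser name fails here at startup rather than at
+        # request time.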
self.enable_reasoning: bool = enable_reasoning + self.reasoning_parser: Optional[Callable[[AnyTokenizer], + ReasoningParser]] = None + if self.enable_reasoning: + try: + self.reasoning_parser = ( + ReasoningParserManager.get_reasoning_parser( + reasoning_parser)) + except Exception as e: + raise TypeError("Error: --enable-reasoning requires " + f"reasoning_parser:'{reasoning_parser}' " + "which has not been registered") from e + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None + if self.enable_auto_tools: + try: + if (tool_parser == "pythonic" and + model_config.model.startswith("meta-llama/Llama-3.2")): + logger.warning( + "Llama3.2 models may struggle to emit valid pythonic" + " tool calls") + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) + except Exception as e: + raise TypeError("Error: --enable-auto-tool-choice requires " + f"tool_parser:'{tool_parser}' which has not " + "been registered") from e + + self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + source = self.model_config.generation_config + source = "model" if source == "auto" else source + logger.info("Using default chat sampling params from %s: %s", + source, self.default_sampling_params) + + async def create_chat_completion( + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, + ErrorResponse]: + """ + Chat Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI + Chat Completion API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
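+        # After the health check below, the request is preprocessed: LoRA /
+        # prompt adapters are resolved, the tokenizer is fetched, and the
+        # chat template is rendered into engine prompts. Preprocessing
+        # failures are returned as error responses rather than raised.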
+ if self.engine_client.errored: + raise self.engine_client.dead_error + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + model_name = self._get_model_name(request.model, lora_request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + tool_parser = self.tool_parser + + if isinstance(tokenizer, MistralTokenizer): + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` + maybe_serialize_tool_calls(request) + truncate_tool_call_ids(request) + + if (request.tool_choice == "auto" and + not (self.enable_auto_tools and tool_parser is not None) + and not isinstance(tokenizer, MistralTokenizer)): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set" + ) + + tool_dicts = None if request.tools is None else [ + tool.model_dump() for tool in request.tools + ] + + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self.chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=request.documents, + chat_template_kwargs=request.chat_template_kwargs, + tool_parser=tool_parser, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except (ValueError, TypeError, RuntimeError, + jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + request_id = "chatcmpl-" \ + f"{self._base_request_id(raw_request, request.request_id)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # Schedule the request and get the result generator. 
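+        # `default_max_tokens` computed below is the remaining context
+        # budget (max_model_len minus the prompt length);
+        # to_sampling_params / to_beam_search_params then take the minimum
+        # of that budget, the user-supplied max_tokens and any server-side
+        # default.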
+ generators: list[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens, self.default_sampling_params) + else: + sampling_params = request.to_sampling_params( + default_max_tokens, + self.model_config.logits_processor_pattern, + self.default_sampling_params) + + self._log_inputs(request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(sampling_params, BeamSearchParams): + generator = self.engine_client.beam_search( + prompt=engine_prompt, + request_id=request_id, + params=sampling_params, + ) + else: + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert len(generators) == 1 + result_generator, = generators + + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, result_generator, request_id, model_name, + conversation, tokenizer, request_metadata) + + try: + return await self.chat_completion_full_generator( + request, result_generator, request_id, model_name, + conversation, tokenizer, request_metadata) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + def get_chat_request_role(self, request: ChatCompletionRequest) -> str: + if request.add_generation_prompt: + return self.response_role + return request.messages[-1]["role"] + + @staticmethod + def _bracket_level(s: str, opening='{', closing='}') -> int: + """ + Calculate the current level of nested brackets in a given string. 
+ """ + level = 0 + for char in s: + if char == opening: + level += 1 + elif char == closing: + level -= 1 + return level + + @staticmethod + def _filter_delta_text(delta_text: str, + previous_text: str) -> tuple[str, bool]: + # remove last '},' of the tool definition stemming from the + # "name"/"parameters" outer object or closing ']' of the tool list + # count occurrences of opening and closing curly braces and + # once level 0 is reached stop outputting text + # if 0 is reached while parsing the delta_text we know the current + # tool will finish in this current iteration + bracket_level = OpenAIServingChat._bracket_level(previous_text) + updated_delta, passed_zero = "", False + for c in delta_text: + if c == '{': + bracket_level += 1 + passed_zero = bracket_level == 0 + elif c == '}': + bracket_level -= 1 + passed_zero = bracket_level == 0 + + if bracket_level != 0: + updated_delta += c + else: + # if a comma is reached at level 0 we can stop + if c == ',': + break + return updated_delta, passed_zero + + def extract_tool_call_required_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + function_name_returned: bool, + ) -> tuple[Optional[DeltaMessage], bool]: + try: + obj = partial_json_parser.loads(current_text) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + obj = None + + # check if the current text is a valid array + # containing a partial tool calling object + # if not repeat + if obj is None or not isinstance(obj, list) or not len(obj) > 0: + function_name_returned = False + delta_message = None + else: + _, finishes_previous_tool = OpenAIServingChat._filter_delta_text( + delta_text, previous_text) + # take the last tool call from the generated list + current_tool_call = obj[-1] + + # once parameters have been generated the name is complete as well + if not finishes_previous_tool and ("name" not in current_tool_call + or "parameters" + not in current_tool_call): + function_name_returned = False + delta_message = None + else: + if not function_name_returned: + # get partly generated arguments from the latest tool call + param_match = re.search(r'.*"parameters":\s*(.*)', + current_text) + arguments = param_match.group(1) if param_match else "" + arguments, _ = OpenAIServingChat._filter_delta_text( + arguments, previous_text) + + # if this iteration finishes a previous tool call but a + # new incomplete tool is already generated, take the + # previous from the list + if (finishes_previous_tool + and "parameters" not in current_tool_call): + current_tool_call = obj[-2] + + function_name_returned = True + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(function=DeltaFunctionCall( + name=current_tool_call["name"], + arguments=arguments), + index=len(obj) - 1, + type="function") + ]) + + else: + delta_text, _ = OpenAIServingChat._filter_delta_text( + delta_text, previous_text) + + if delta_text != "": + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall( + function=DeltaFunctionCall( + # OpenAI API returns None + # instead of name every time + name=None, + arguments=delta_text), + index=len(obj) - 1, + type="function") + ]) + else: + delta_message = None + + return delta_message, function_name_returned + + async def chat_completion_stream_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: 
RequestResponseMetadata, + ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) + chunk_object_type: Final = "chat.completion.chunk" + first_iteration = True + + # Send response for each token for each request.n (index) + num_choices = 1 if request.n is None else request.n + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + num_cached_tokens = None + + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + should_stream_with_reasoning_parsing = ( + self._should_stream_with_reasoning_parsing(request)) + + all_previous_token_ids: Optional[list[list[int]]] + function_name_returned: Optional[list[bool]] = None + + # Only one of these will be used, thus previous_texts and + # all_previous_token_ids will not be used twice in the same iteration. + if tool_choice_auto or should_stream_with_reasoning_parsing: + # These are only required in "auto" tool choice case + previous_texts = [""] * num_choices + all_previous_token_ids = [[]] * num_choices + # For reasoning parser and tool call all enabled + added_content_delta_arr = [False] * num_choices + reasoning_end_arr = [False] * num_choices + elif request.tool_choice == "required": + previous_texts = [""] * num_choices + function_name_returned = [False] * num_choices + all_previous_token_ids = None + else: + previous_texts, all_previous_token_ids = None, None + + try: + # There is no need to check if the reasoning_parser is None + # because the should_stream_with_reasoning_parsing check + # already ensures that the reasoning_parser is not None. + # but the pre-commit hook requires it. + if should_stream_with_reasoning_parsing and \ + self.reasoning_parser is not None: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + # Prepare the tool parser if it's needed + try: + if tool_choice_auto and self.tool_parser: + tool_parsers: list[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) + ] * num_choices + else: + tool_parsers = [None] * num_choices + except Exception as e: + logger.exception("Error in tool parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + stream_options = request.stream_options + if stream_options: + include_usage = stream_options.include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats + else: + include_usage, include_continuous_usage = False, False + + try: + async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). 
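+                # Every chunk emitted from here on follows the OpenAI SSE
+                # convention: one "data: {json}\n\n" frame per chunk, with a
+                # final "data: [DONE]\n\n" after all choices finish.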
+ if first_iteration: + num_cached_tokens = res.num_cached_tokens + # Send first response for each request.n (index) with + # the role + role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, + content="", + ), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # if continuous usage stats are requested, add it + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if conversation and "content" in conversation[ + -1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + if last_msg_content: + for i in range(num_choices): + choice_data = ( + ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + logprobs=None, + finish_reason=None)) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + tool_parser = tool_parsers[i] + + if finish_reason_sent[i]: + continue + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.top_logprobs, + return_as_token_id=request. + return_tokens_as_token_ids, + ) + else: + logprobs = None + + delta_text = output.text + + if not delta_text and not output.token_ids and \ + not previous_num_tokens[i]: + # Chunked prefill case, don't return empty chunks + continue + + delta_message: Optional[DeltaMessage] + + # just update previous_texts and previous_token_ids + if tool_choice_auto or should_stream_with_reasoning_parsing: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + + # handle streaming deltas for tools with named tool_choice + if tool_choice_function_name: + if (self.enable_reasoning + and not reasoning_parser.is_reasoning_end( + previous_token_ids)): + assert reasoning_parser is not None + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # When encountering think end id in delta_token_ids, + # process the `content`. 
Only keep 'content', + # remove 'reasoning_content' + if reasoning_parser.is_reasoning_end( + list(output.token_ids)): + if delta_message and delta_message.content: + # This need to be added to next `delta_text` + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + else: + # Just to add remaining `content` + if self.enable_reasoning: + delta_text = previous_text + delta_text + current_text = "" + + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text), + index=i) + ]) + + elif request.tool_choice == "required": + assert previous_texts is not None + assert function_name_returned is not None + previous_text = previous_texts[i] + current_text = previous_text + delta_text + fn_name_returned = function_name_returned[i] + + delta_message, function_name_returned[i] = ( + self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + function_name_returned=fn_name_returned)) + + # update the previous values for the next iteration + previous_texts[i] = current_text + + # handle streaming deltas for tools with "auto" tool choice + # and reasoning parser + elif tool_choice_auto and self.enable_reasoning: + assert tool_parser is not None + assert reasoning_parser is not None + assert added_content_delta_arr is not None + assert reasoning_end_arr is not None + if not reasoning_end_arr[i]: + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + + # When encountering think end id in delta_token_ids, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning_content'. + if reasoning_parser.is_reasoning_end( + list(output.token_ids)): + reasoning_end_arr[i] = True + current_token_ids = \ + reasoning_parser.extract_content_ids( + list(output.token_ids)) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + # handle tool calls only after reasoning is done, + else: + delta_token_ids = list(output.token_ids) + # First time to tool call, + # add the remaining text and token ids + # to delta from previous + if not added_content_delta_arr[i]: + added_content_delta_arr[i] = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request)) + # when only tool calls + elif tool_choice_auto: + assert tool_parser is not None + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request)) + # when only reasoning + elif self.enable_reasoning: + assert reasoning_parser is not None + delta_message = (reasoning_parser. 
+ extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # handle streaming just a content delta + else: + delta_message = DeltaMessage(content=delta_text) + + # update the previous values for the next iteration + if tool_choice_auto or should_stream_with_reasoning_parsing: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + + # set the previous values for the next iteration + previous_num_tokens[i] += len(output.token_ids) + + # if the message delta is None (e.g. because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None) + + # if the model is finished generating + else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + auto_tools_called = False + if tool_parser: + auto_tools_called = len( + tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr + ) - 1 if auto_tools_called else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + latest_delta_len = 0 + if ((isinstance( + delta_message.tool_calls[0].function, + DeltaFunctionCall)) and isinstance( + delta_message.tool_calls[0].function. + arguments, str)): + latest_delta_len = len( + delta_message.tool_calls[0].function. + arguments) + + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {}), + ensure_ascii=False) + + # get what we've streamed so far for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + if (latest_delta_len > 0): + actual_call = actual_call[:-latest_delta_len] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). 
+ model_dump(exclude_none=True)) + ]) + + # Send the finish response for each request.n only once + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=output.finish_reason + if not auto_tools_called else "tool_calls", + stop_reason=output.stop_reason) + + finish_reason_sent[i] = True + + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage + if include_usage: + completion_tokens = sum(previous_num_tokens) + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + if self.enable_prompt_tokens_details and num_cached_tokens: + final_usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in chat completion stream generator.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> Union[ErrorResponse, ChatCompletionResponse]: + + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert final_res is not None + + choices: list[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + 
return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + + should_stream_with_reasoning_parsing = ( + self._should_stream_with_reasoning_parsing(request)) + + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = False + + if should_stream_with_reasoning_parsing and \ + self.reasoning_parser is not None: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + # If the reasoning parser is enabled, + # tool calls are extracted exclusively from the content. + reasoning_content, content = ( + reasoning_parser.extract_reasoning_content( + output.text, request=request)) + else: + reasoning_content = None + content = output.text + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools or not self.tool_parser) and \ + (not isinstance(request.tool_choice, + ChatCompletionNamedToolChoiceParam + ) and request.tool_choice != "required"): + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content="", + tool_calls=[ + tool_call_class(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=content)) + ]) + + elif request.tool_choice and request.tool_choice == "required": + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall + + # the fields of FunctionDefinition are a superset of the + # tool call outputs and can be used for parsing + tool_calls = TypeAdapter( + list[FunctionDefinition]).validate_json(output.text) + message = ChatMessage( + role=role, + content="", + tool_calls=[ + tool_call_class(function=FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters))) + for tool_call in tool_calls + ]) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + return self.create_error_response(str(e)) + + tool_call_info = tool_parser.extract_tool_calls( + content if content is not None else "", request=request) + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. 
The same is not true for named function calls + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if auto_tools_called else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason) + choices.append(choice_data) + + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if conversation and "content" in conversation[-1] and conversation[ + -1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + if isinstance(last_msg_content, list): + last_msg_content = "\n".join(msg['text'] + for msg in last_msg_content) + + for choice in choices: + full_message = last_msg_content + (choice.message.content + or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) + + request_metadata.final_usage_info = usage + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + ) + + return response + + def _get_top_logprobs( + self, logprobs: dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: AnyTokenizer, + should_return_as_token_id: bool) -> list[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb(token=(token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=should_return_as_token_id)), + logprob=max(p[1].logprob, -9999.0), + bytes=list( + token.encode("utf-8", errors="replace"))) + for i, p in enumerate(logprobs.items()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], + tokenizer: AnyTokenizer, + num_output_top_logprobs: Optional[int] = None, + return_as_token_id: Optional[bool] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + logprobs_content: list[ChatCompletionLogProbsContent] = [] + + should_return_as_token_id = return_as_token_id if \ + return_as_token_id is not None else self.return_tokens_as_token_ids + for i, token_id in enumerate(token_ids): + 
step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = tokenizer.decode(token_id) + if should_return_as_token_id: + token = f"token_id:{token_id}" + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=token, + bytes=list(token.encode("utf-8", errors="replace")), + )) + else: + step_token = step_top_logprobs[token_id] + step_decoded = step_token.decoded_token + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=self._get_decoded_token( + step_token, + token_id, + tokenizer, + should_return_as_token_id, + ), + logprob=max(step_token.logprob, -9999.0), + bytes=None if step_decoded is None else list( + step_decoded.encode("utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, num_output_top_logprobs, + tokenizer, should_return_as_token_id), + )) + + return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. + + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_stream_with_reasoning_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the + reasoning parser that was configured. + + We only want to do this IF reasoning is enabled and a reasoning + parser is configured. + """ + return self.enable_reasoning and self.reasoning_parser is not None + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. 
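+        For example (illustrative, with a made-up tool argument): if the
+        partial-JSON parser has streamed the arguments '{"city": "Shanghai"'
+        so far but the autocompleted call is '{"city": "Shanghai"}', the
+        final chunk still has to carry the closing '}' before the
+        finish_reason is sent.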
+ """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + output.finish_reason is not None + and self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py new file mode 100644 index 0000000..1067f35 --- /dev/null +++ b/vllm/entrypoints/openai/serving_completion.py @@ -0,0 +1,557 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import time +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Optional, Union, cast + +import jinja2 +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, + RequestResponseMetadata, + UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + clamp_prompt_logprobs) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import merge_async_iterators + +logger = init_logger(__name__) + + +class OpenAIServingCompletion(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + source = self.model_config.generation_config + source = "model" if source == "auto" else source + logger.info("Using default completion sampling params from %s: %s", + source, self.default_sampling_params) + + async def create_completion( + self, + request: CompletionRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. + + NOTE: Currently we do not support the following feature: + - suffix (the language models we currently support do not support + suffix) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
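+        # A minimal illustrative request for this endpoint (assuming the
+        # container from the README is published on localhost:8000 and the
+        # model is served under the default name "llm"; adjust both to your
+        # deployment):
+        #
+        #   curl http://localhost:8000/v1/completions \
+        #     -H "Content-Type: application/json" \
+        #     -d '{"model": "llm", "prompt": "上海是一座", "max_tokens": 64}'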
+ if self.engine_client.errored: + raise self.engine_client.dead_error + + # Return error for unsupported features. + if request.suffix is not None: + return self.create_error_response( + "suffix is not currently supported") + + request_id = f"cmpl-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + request_prompts, engine_prompts = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + except TypeError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + except RuntimeError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + except jinja2.TemplateError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens, self.default_sampling_params) + else: + sampling_params = request.to_sampling_params( + default_max_tokens, + self.model_config.logits_processor_pattern, + self.default_sampling_params) + + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(sampling_params, BeamSearchParams): + generator = self.engine_client.beam_search( + prompt=engine_prompt, + request_id=request_id, + params=sampling_params, + ) + else: + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id_item, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + model_name = self._get_model_name(request.model, lora_request) + num_prompts = len(engine_prompts) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. Noting that best_of is only supported in V0. In addition, + # we do not stream the results when use beam search. 
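+        # An illustrative streaming request that satisfies the conditions in
+        # the comment above (same localhost:8000 / "llm" assumptions as in
+        # the README); "stream_options" is optional and is what
+        # completion_stream_generator below reads for usage reporting:
+        #
+        #   curl http://localhost:8000/v1/completions \
+        #     -H "Content-Type: application/json" \
+        #     -d '{"model": "llm", "prompt": "你好", "max_tokens": 32,
+        #          "stream": true,
+        #          "stream_options": {"include_usage": true,
+        #                             "continuous_usage_stats": true}}'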
+ stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + # Streaming response + if stream: + return self.completion_stream_generator( + request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=num_prompts, + tokenizer=tokenizer, + request_metadata=request_metadata) + + # Non-streaming response + final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts + try: + async for i, res in result_generator: + final_res_batch[i] = res + + for i, final_res in enumerate(final_res_batch): + assert final_res is not None + + # The output should contain the input text + # We did not pass it into vLLM engine to avoid being redundant + # with the inputs token IDs + if final_res.prompt is None: + final_res.prompt = request_prompts[i]["prompt"] + + final_res_batch_checked = cast(list[RequestOutput], + final_res_batch) + + response = self.request_output_to_completion_response( + final_res_batch_checked, + request, + request_id, + created_time, + model_name, + tokenizer, + request_metadata, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. + if request.stream: + response_json = response.model_dump_json() + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return fake_stream_generator() + + return response + + async def completion_stream_generator( + self, + request: CompletionRequest, + result_generator: AsyncIterator[tuple[int, RequestOutput]], + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> AsyncGenerator[str, None]: + num_choices = 1 if request.n is None else request.n + previous_text_lens = [0] * num_choices * num_prompts + previous_num_tokens = [0] * num_choices * num_prompts + has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts + + stream_options = request.stream_options + if stream_options: + include_usage = stream_options.include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats + else: + include_usage, include_continuous_usage = False, False + + try: + async for prompt_idx, res in result_generator: + prompt_token_ids = res.prompt_token_ids + prompt_logprobs = res.prompt_logprobs + prompt_text = res.prompt + + # Prompt details are excluded from later streamed outputs + if res.prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + + delta_token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[dict[ + int, Logprob]]]] + + for output in res.outputs: + i = output.index + prompt_idx * num_choices + + assert request.max_tokens is not None + if request.echo and not has_echoed[i]: + assert prompt_token_ids is not None + assert prompt_text is not None + if request.max_tokens == 0: + # only return the prompt + delta_text = prompt_text + delta_token_ids = prompt_token_ids + out_logprobs = prompt_logprobs + else: + assert prompt_logprobs is not None + # echo the prompt and first token + delta_text = prompt_text + output.text + delta_token_ids = [ + *prompt_token_ids, *output.token_ids 
+ ] + out_logprobs = [ + *prompt_logprobs, + *(output.logprobs or []), + ] + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text + delta_token_ids = output.token_ids + out_logprobs = output.logprobs + + if not delta_text and not delta_token_ids \ + and not previous_num_tokens[i]: + # Chunked prefill case, don't return empty chunks + continue + + if request.logprobs is not None: + assert out_logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_completion_logprobs( + token_ids=delta_token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.logprobs, + tokenizer=tokenizer, + initial_text_offset=previous_text_lens[i], + return_as_token_id=request. + return_tokens_as_token_ids, + ) + else: + logprobs = None + + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) + finish_reason = output.finish_reason + stop_reason = output.stop_reason + + chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + stop_reason=stop_reason, + ) + ]) + if include_continuous_usage: + prompt_tokens = num_prompt_tokens[prompt_idx] + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + + response_json = chunk.model_dump_json(exclude_unset=False) + yield f"data: {response_json}\n\n" + + total_prompt_tokens = sum(num_prompt_tokens) + total_completion_tokens = sum(previous_num_tokens) + final_usage_info = UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens) + + if include_usage: + final_usage_chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[], + usage=final_usage_info, + ) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=False, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + request_metadata.final_usage_info = final_usage_info + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + + def request_output_to_completion_response( + self, + final_res_batch: list[RequestOutput], + request: CompletionRequest, + request_id: str, + created_time: int, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> CompletionResponse: + choices: list[CompletionResponseChoice] = [] + num_prompt_tokens = 0 + num_generated_tokens = 0 + + for final_res in final_res_batch: + prompt_token_ids = final_res.prompt_token_ids + assert prompt_token_ids is not None + prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs) + prompt_text = final_res.prompt + + token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[dict[int, + Logprob]]]] + + for output in final_res.outputs: + assert request.max_tokens is not None + if request.echo: + assert prompt_text is not None + if request.max_tokens == 0: + token_ids = prompt_token_ids + out_logprobs = prompt_logprobs + output_text = prompt_text + else: + token_ids = [*prompt_token_ids, *output.token_ids] + + if request.logprobs is None: 
+ out_logprobs = None + else: + assert prompt_logprobs is not None + assert output.logprobs is not None + out_logprobs = [ + *prompt_logprobs, + *output.logprobs, + ] + + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + out_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_completion_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.logprobs, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + prompt_logprobs=final_res.prompt_logprobs, + ) + choices.append(choice_data) + + num_generated_tokens += len(output.token_ids) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + request_metadata.final_usage_info = usage + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + def _create_completion_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], + num_output_top_logprobs: int, + tokenizer: AnyTokenizer, + initial_text_offset: int = 0, + return_as_token_id: Optional[bool] = None, + ) -> CompletionLogProbs: + """Create logprobs for OpenAI Completion API.""" + out_text_offset: list[int] = [] + out_token_logprobs: list[Optional[float]] = [] + out_tokens: list[str] = [] + out_top_logprobs: list[Optional[dict[str, float]]] = [] + + last_token_len = 0 + + should_return_as_token_id = return_as_token_id if \ + return_as_token_id is not None else self.return_tokens_as_token_ids + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = tokenizer.decode(token_id) + if should_return_as_token_id: + token = f"token_id:{token_id}" + + out_tokens.append(token) + out_token_logprobs.append(None) + out_top_logprobs.append(None) + else: + step_token = step_top_logprobs[token_id] + + token = self._get_decoded_token( + step_token, + token_id, + tokenizer, + return_as_token_id=should_return_as_token_id, + ) + token_logprob = max(step_token.logprob, -9999.0) + + out_tokens.append(token) + out_token_logprobs.append(token_logprob) + + # makes sure to add the top num_output_top_logprobs + 1 + # logprobs, as defined in the openai API + # (cf. 
https://github.com/openai/openai-openapi/blob/ + # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) + out_top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token(top_lp[1], + top_lp[0], + tokenizer, + return_as_token_id=should_return_as_token_id): + max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + }) + + if len(out_text_offset) == 0: + out_text_offset.append(initial_text_offset) + else: + out_text_offset.append(out_text_offset[-1] + last_token_len) + last_token_len = len(token) + + return CompletionLogProbs( + text_offset=out_text_offset, + token_logprobs=out_token_logprobs, + tokens=out_tokens, + top_logprobs=out_top_logprobs, + ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py new file mode 100644 index 0000000..0ee5867 --- /dev/null +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import base64 +import time +from collections.abc import AsyncGenerator +from typing import Final, Literal, Optional, Union, cast + +import numpy as np +from fastapi import Request +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, UsageInfo) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, + PoolingRequestOutput) +from vllm.utils import merge_async_iterators + +logger = init_logger(__name__) + + +def _get_embedding( + output: EmbeddingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[list[float], str]: + if encoding_format == "float": + return output.embedding + elif encoding_format == "base64": + # Force to use float32 for base64 encoding + # to match the OpenAI python client behavior + embedding_bytes = np.array(output.embedding, dtype="float32").tobytes() + return base64.b64encode(embedding_bytes).decode("utf-8") + + assert_never(encoding_format) + + +class OpenAIServingEmbedding(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_embedding( + self, + request: EmbeddingRequest, + raw_request: Optional[Request] = None, + ) -> Union[EmbeddingResponse, ErrorResponse]: + """ + Embedding API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. 
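+        Example request (illustrative; assumes the container from the README
+        is published on localhost:8000 and serves an embedding-capable model
+        under the default name "llm"):
+
+            curl http://localhost:8000/v1/embeddings \
+              -H "Content-Type: application/json" \
+              -d '{"model": "llm", "input": ["今天天气不错"], "encoding_format": "float"}'
+
+        With "encoding_format": "base64", each embedding is returned as the
+        base64 encoding of its float32 byte representation (see
+        _get_embedding above).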
+ """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + encoding_format = request.encoding_format + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = self._get_model_name(request.model) + request_id = f"embd-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + truncate_prompt_tokens = None + + if request.truncate_prompt_tokens is not None: + if request.truncate_prompt_tokens <= self.max_model_len: + truncate_prompt_tokens = request.truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for embedding models") + + if isinstance(request, EmbeddingChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + # In embedding requests, we are not generating tokens, + # so there is no need to append extra tokens to the input + add_generation_prompt=False, + continue_final_message=False, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except (ValueError, TypeError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. 
+ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: list[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(list[PoolingRequestOutput], + final_res_batch) + + response = self.request_output_to_embedding_response( + final_res_batch_checked, + request_id, + created_time, + model_name, + encoding_format, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def request_output_to_embedding_response( + self, + final_res_batch: list[PoolingRequestOutput], + request_id: str, + created_time: int, + model_name: str, + encoding_format: Literal["float", "base64"], + ) -> EmbeddingResponse: + items: list[EmbeddingResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + embedding_res = EmbeddingRequestOutput.from_base(final_res) + + item = EmbeddingResponseData( + index=idx, + embedding=_get_embedding(embedding_res.outputs, + encoding_format), + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return EmbeddingResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py new file mode 100644 index 0000000..bbc8edd --- /dev/null +++ b/vllm/entrypoints/openai/serving_engine.py @@ -0,0 +1,557 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import Iterable, Iterator, Mapping, Sequence +from concurrent.futures.thread import ThreadPoolExecutor +from http import HTTPStatus +from typing import Annotated, Any, Callable, Optional, TypedDict, Union + +from fastapi import Request +from pydantic import Field +from starlette.datastructures import Headers + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages_futures, + resolve_chat_template_content_format) +from 
vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest, + DetokenizeRequest, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + ErrorResponse, RerankRequest, + ScoreRequest, + TokenizeChatRequest, + TokenizeCompletionRequest, + TranscriptionRequest) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.tool_parsers import ToolParser +# yapf: enable +from vllm.inputs import TokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob, PromptLogprobs +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import is_list_of, make_async, random_uuid + +logger = init_logger(__name__) + +CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, + EmbeddingCompletionRequest, RerankRequest, + ScoreRequest, TokenizeCompletionRequest] + +ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, + TokenizeChatRequest] + +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, + TranscriptionRequest] + + +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: list[int] + + +RequestPrompt = Union[list[int], str, TextTokensPrompt] + + +class OpenAIServing: + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__() + + self.engine_client = engine_client + self.model_config = model_config + self.max_model_len = model_config.max_model_len + + self.models = models + + self.request_logger = request_logger + self.return_tokens_as_token_ids = return_tokens_as_token_ids + + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self._tokenize_prompt_input_async = make_async( + self._tokenize_prompt_input, executor=self._tokenizer_executor) + self._tokenize_prompt_input_or_inputs_async = make_async( + self._tokenize_prompt_input_or_inputs, + executor=self._tokenizer_executor) + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) + + def create_streaming_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + json_str = json.dumps({ + "error": + self.create_error_response(message=message, + err_type=err_type, + status_code=status_code).model_dump() + }) + return json_str + + async def _check_model( + self, + request: AnyRequest, + ) -> Optional[ErrorResponse]: + if self._is_model_supported(request.model): + return None + if request.model in [ + lora.lora_name for lora in self.models.lora_requests + ]: + return None + if request.model in [ + prompt_adapter.prompt_adapter_name + for prompt_adapter in self.models.prompt_adapter_requests + ]: + return None + return self.create_error_response( + message=f"The model `{request.model}` does not exist.", + 
err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + def _maybe_get_adapters( + self, request: AnyRequest + ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[ + None, PromptAdapterRequest]]: + if self._is_model_supported(request.model): + return None, None + for lora in self.models.lora_requests: + if request.model == lora.lora_name: + return lora, None + for prompt_adapter in self.models.prompt_adapter_requests: + if request.model == prompt_adapter.prompt_adapter_name: + return None, prompt_adapter + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") + + def _normalize_prompt_text_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt: str, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + add_special_tokens: bool, + ) -> TextTokensPrompt: + if (self.model_config.encoder_config is not None + and self.model_config.encoder_config.get( + "do_lower_case", False)): + prompt = prompt.lower() + + if truncate_prompt_tokens is None: + encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + else: + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) + + input_ids = encoded.input_ids + + input_text = prompt + + return self._validate_input(request, input_ids, input_text) + + def _normalize_prompt_tokens_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_ids: list[int], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + ) -> TextTokensPrompt: + if truncate_prompt_tokens is None: + input_ids = prompt_ids + else: + input_ids = prompt_ids[-truncate_prompt_tokens:] + + input_text = tokenizer.decode(input_ids) + + return self._validate_input(request, input_ids, input_text) + + def _validate_input( + self, + request: AnyRequest, + input_ids: list[int], + input_text: str, + ) -> TextTokensPrompt: + token_num = len(input_ids) + + # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens + if isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest, + ScoreRequest, RerankRequest)): + + operation = "score" if isinstance(request, ScoreRequest) \ + else "embedding generation" + if token_num > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for {operation}. " + f"Please reduce the length of the input.") + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens + # and does not require model context length validation + if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, + DetokenizeRequest)): + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # chat completion endpoint supports max_completion_tokens + if isinstance(request, ChatCompletionRequest): + # TODO(#9845): remove max_tokens when field dropped from OpenAI API + max_tokens = request.max_completion_tokens or request.max_tokens + else: + max_tokens = request.max_tokens + if max_tokens is None: + if token_num >= self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. 
However, you requested " + f"{token_num} tokens in the messages, " + f"Please reduce the length of the messages.") + elif token_num + max_tokens > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.") + + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + + def _tokenize_prompt_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_input: Union[str, list[int]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> TextTokensPrompt: + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes single input. + """ + return next( + self._tokenize_prompt_inputs( + request, + tokenizer, + [prompt_input], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + )) + + def _tokenize_prompt_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_inputs: Iterable[Union[str, list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> Iterator[TextTokensPrompt]: + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes multiple inputs. + """ + for text in prompt_inputs: + if isinstance(text, str): + yield self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=text, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + yield self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=text, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + + def _tokenize_prompt_input_or_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> list[TextTokensPrompt]: + """ + Tokenize/detokenize depending on the input format. + + According to `OpenAI API `_ + , each input can be a string or array of tokens. Note that each request + can pass one or more inputs. 
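+        All of the following shapes are accepted for input_or_inputs
+        (illustrative values; the token ids are made up):
+
+            "Hello world"               # a single string prompt
+            ["Hello", "world"]          # a batch of string prompts
+            [1, 2, 3]                   # a single token-id prompt
+            [[1, 2, 3], [4, 5, 6]]      # a batch of token-id prompts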
+ """ + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is True" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + return [ + self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + if prompt_input["is_tokens"] is False else + self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + for prompt_input in parse_and_batch_prompt(input_or_inputs) + ] + + async def _preprocess_completion( + self, + request: CompletionLikeRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]: + request_prompts = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + + engine_prompts = [ + TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) + for request_prompt in request_prompts + ] + + return request_prompts, engine_prompts + + async def _preprocess_chat( + self, + request: ChatLikeRequest, + tokenizer: AnyTokenizer, + messages: list[ChatCompletionMessageParam], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: Optional[list[dict[str, Any]]] = None, + documents: Optional[list[dict[str, str]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = False, + ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], + list[TokensPrompt]]: + model_config = self.model_config + + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tool_dicts, + chat_template_content_format, + tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + conversation, mm_data_future = parse_chat_messages_futures( + messages, + model_config, + tokenizer, + content_format=resolved_content_format, + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + request_prompt: Union[str, list[int]] + if isinstance(tokenizer, MistralTokenizer): + request_prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + **_chat_template_kwargs, + ) + else: + request_prompt = apply_hf_chat_template( + tokenizer, + trust_remote_code=model_config.trust_remote_code, + conversation=conversation, + **_chat_template_kwargs, + ) + + mm_data = await mm_data_future + + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + should_parse_tools = tool_parser is not None and (hasattr( + request, "tool_choice") 
and request.tool_choice != "none") + + if should_parse_tools: + if not isinstance(request, ChatCompletionRequest): + msg = "Tool usage is only supported for Chat Completions API" + raise NotImplementedError(msg) + + request = tool_parser(tokenizer).adjust_request( # type: ignore + request=request) + + if isinstance(request_prompt, str): + prompt_inputs = await self._tokenize_prompt_input_async( + request, + tokenizer, + request_prompt, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + # For MistralTokenizer + assert is_list_of(request_prompt, int), ( + "Prompt has to be either a string or a list of token ids") + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt) + + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + + return conversation, [request_prompt], [engine_prompt] + + def _log_inputs( + self, + request_id: str, + inputs: RequestPrompt, + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + if self.request_logger is None: + return + + if isinstance(inputs, str): + prompt = inputs + prompt_token_ids = None + elif isinstance(inputs, list): + prompt = None + prompt_token_ids = inputs + else: + prompt = inputs["prompt"] + prompt_token_ids = inputs["prompt_token_ids"] + + self.request_logger.log_inputs( + request_id, + prompt, + prompt_token_ids, + params=params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _get_trace_headers( + self, + headers: Headers, + ) -> Optional[Mapping[str, str]]: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + + @staticmethod + def _base_request_id(raw_request: Optional[Request], + default: Optional[str] = None) -> Optional[str]: + """Pulls the request id to use from a header, if provided""" + default = default or random_uuid() + if raw_request is None: + return default + + return raw_request.headers.get("X-Request-Id", default) + + @staticmethod + def _get_decoded_token(logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + return_as_token_id: bool = False) -> str: + if return_as_token_id: + return f"token_id:{token_id}" + + if logprob.decoded_token is not None: + return logprob.decoded_token + return tokenizer.decode(token_id) + + def _is_model_supported(self, model_name: Optional[str]) -> bool: + if not model_name: + return True + return self.models.is_base_model(model_name) + + def _get_model_name(self, + model_name: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> str: + if lora_request: + return lora_request.lora_name + if not model_name: + return self.models.base_model_paths[0].name + return model_name + + +def clamp_prompt_logprobs( + prompt_logprobs: Union[PromptLogprobs, + None]) -> Union[PromptLogprobs, None]: + if prompt_logprobs is None: + return prompt_logprobs + + for logprob_dict in prompt_logprobs: + if logprob_dict is None: + continue + for logprob_values in logprob_dict.values(): + if logprob_values.logprob == 
float('-inf'): + logprob_values.logprob = -9999.0 + return prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py new file mode 100644 index 0000000..7a68452 --- /dev/null +++ b/vllm/entrypoints/openai/serving_models.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import pathlib +from dataclasses import dataclass +from http import HTTPStatus +from typing import Optional, Union + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoRAAdapterRequest, + ModelCard, ModelList, + ModelPermission, + UnloadLoRAAdapterRequest) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.utils import AtomicCounter + +logger = init_logger(__name__) + + +@dataclass +class BaseModelPath: + name: str + model_path: str + + +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + +@dataclass +class LoRAModulePath: + name: str + path: str + base_model_name: Optional[str] = None + + +class OpenAIServingModels: + """Shared instance to hold data about the loaded base model(s) and adapters. + + Handles the routes: + - /v1/models + - /v1/load_lora_adapter + - /v1/unload_lora_adapter + """ + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: list[BaseModelPath], + *, + lora_modules: Optional[list[LoRAModulePath]] = None, + prompt_adapters: Optional[list[PromptAdapterPath]] = None, + ): + super().__init__() + + self.base_model_paths = base_model_paths + self.max_model_len = model_config.max_model_len + self.engine_client = engine_client + + self.static_lora_modules = lora_modules + self.lora_requests: list[LoRARequest] = [] + self.lora_id_counter = AtomicCounter(0) + + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with pathlib.Path(prompt_adapter.local_path, + "adapter_config.json").open() as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + + async def init_static_loras(self): + """Loads all static LoRA modules. + Raises if any fail to load""" + if self.static_lora_modules is None: + return + for lora in self.static_lora_modules: + load_request = LoadLoRAAdapterRequest(lora_path=lora.path, + lora_name=lora.name) + load_result = await self.load_lora_adapter( + request=load_request, base_model_name=lora.base_model_name) + if isinstance(load_result, ErrorResponse): + raise ValueError(load_result.message) + + def is_base_model(self, model_name) -> bool: + return any(model.name == model_name for model in self.base_model_paths) + + def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: + """Returns the appropriate model name depending on the availability + and support of the LoRA or base model. + Parameters: + - lora: LoRARequest that contain a base_model_name. + Returns: + - str: The name of the base model or the first available model path. 
+ """ + if lora_request is not None: + return lora_request.lora_name + return self.base_model_paths[0].name + + async def show_available_models(self) -> ModelList: + """Show available models. This includes the base model and all + adapters""" + model_cards = [ + ModelCard(id=base_model.name, + max_model_len=self.max_model_len, + root=base_model.model_path, + permission=[ModelPermission()]) + for base_model in self.base_model_paths + ] + lora_cards = [ + ModelCard(id=lora.lora_name, + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, + permission=[ModelPermission()]) + for lora in self.lora_requests + ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.base_model_paths[0].name, + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] + model_cards.extend(lora_cards) + model_cards.extend(prompt_adapter_cards) + return ModelList(data=model_cards) + + async def load_lora_adapter( + self, + request: LoadLoRAAdapterRequest, + base_model_name: Optional[str] = None + ) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + lora_request = LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path) + if base_model_name is not None and self.is_base_model(base_model_name): + lora_request.base_model_name = base_model_name + + # Validate that the adapter can be loaded into the engine + # This will also pre-load it for incoming requests + try: + await self.engine_client.add_lora(lora_request) + except BaseException as e: + error_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + if "No adapter found" in str(e): + error_type = "NotFoundError" + status_code = HTTPStatus.NOT_FOUND + + return create_error_response(message=str(e), + err_type=error_type, + status_code=status_code) + + self.lora_requests.append(lora_request) + logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name, + lora_path) + return f"Success: LoRA adapter '{lora_name}' added successfully." + + async def unload_lora_adapter( + self, + request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_unload_lora_adapter_request(request + ) + if error_check_ret is not None: + return error_check_ret + + lora_name = request.lora_name + self.lora_requests = [ + lora_request for lora_request in self.lora_requests + if lora_request.lora_name != lora_name + ] + logger.info("Removed LoRA adapter: name '%s'", lora_name) + return f"Success: LoRA adapter '{lora_name}' removed successfully." 
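+    # NOTE (illustrative, not part of the upstream code): once these handlers
+    # are wired to the /v1/load_lora_adapter and /v1/unload_lora_adapter
+    # routes listed in the class docstring, an adapter can be swapped at
+    # runtime roughly like this; the server address, adapter name and path
+    # below are hypothetical placeholders:
+    #
+    #   import requests
+    #   requests.post("http://localhost:8000/v1/load_lora_adapter",
+    #                 json={"lora_name": "my-adapter",
+    #                       "lora_path": "/path/to/adapter"})
+    #   requests.post("http://localhost:8000/v1/unload_lora_adapter",
+    #                 json={"lora_name": "my-adapter"})
+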
+ + async def _check_load_lora_adapter_request( + self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' has already been " + "loaded.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def _check_unload_lora_adapter_request( + self, + request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]: + # Check if either 'lora_name' or 'lora_int_id' is provided + if not request.lora_name and not request.lora_int_id: + return create_error_response( + message= + "either 'lora_name' and 'lora_int_id' needs to be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name exists + if not any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' cannot be found.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + return None + + +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py new file mode 100644 index 0000000..779a3ed --- /dev/null +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import base64 +import time +from collections.abc import AsyncGenerator +from typing import Final, Literal, Optional, Union, cast + +import jinja2 +import numpy as np +from fastapi import Request +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, + PoolingChatRequest, + PoolingRequest, PoolingResponse, + PoolingResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.utils import merge_async_iterators + +logger = init_logger(__name__) + + +def _get_data( + output: PoolingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[list[float], str]: + if encoding_format == "float": + return output.data.tolist() + elif encoding_format == "base64": + # Force to use float32 for base64 encoding + # to match the OpenAI python client behavior + pooling_bytes = np.array(output.data, dtype="float32").tobytes() + return base64.b64encode(pooling_bytes).decode("utf-8") + + assert_never(encoding_format) + + +class OpenAIServingPooling(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: 
ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_pooling( + self, + request: PoolingRequest, + raw_request: Optional[Request] = None, + ) -> Union[PoolingResponse, ErrorResponse]: + """ + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + encoding_format = request.encoding_format + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = self._get_model_name(request.model) + request_id = f"pool-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + truncate_prompt_tokens = None + + if request.truncate_prompt_tokens is not None: + if request.truncate_prompt_tokens <= self.max_model_len: + truncate_prompt_tokens = request.truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for pooling models") + + if isinstance(request, PoolingChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + # In pooling requests, we are not generating tokens, + # so there is no need to append extra tokens to the input + add_generation_prompt=False, + continue_final_message=False, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except (ValueError, TypeError, jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. 
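+        # Each prompt in the batch is submitted as its own encode() call with
+        # a per-item request id ("{request_id}-{i}"); the resulting async
+        # generators are merged below so outputs can be collected as they
+        # complete, regardless of order.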
+ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: list[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(list[PoolingRequestOutput], + final_res_batch) + + response = self.request_output_to_pooling_response( + final_res_batch_checked, + request_id, + created_time, + model_name, + encoding_format, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def request_output_to_pooling_response( + self, + final_res_batch: list[PoolingRequestOutput], + request_id: str, + created_time: int, + model_name: str, + encoding_format: Literal["float", "base64"], + ) -> PoolingResponse: + items: list[PoolingResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + item = PoolingResponseData( + index=idx, + data=_get_data(final_res.outputs, encoding_format), + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return PoolingResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py new file mode 100644 index 0000000..73b4288 --- /dev/null +++ b/vllm/entrypoints/openai/serving_score.py @@ -0,0 +1,439 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import time +from collections.abc import AsyncGenerator, Mapping +from typing import Any, Optional, Union + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, + RerankRequest, RerankResponse, + RerankResult, RerankUsage, + ScoreRequest, ScoreResponse, + ScoreResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.score_utils import (_cosine_similarity, + _validate_score_input_lens) +from vllm.inputs.data import TokensPrompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from 
vllm.outputs import PoolingRequestOutput, ScoringRequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast) +from vllm.utils import make_async, merge_async_iterators + +logger = init_logger(__name__) + + +class ServingScores(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + async def _embedding_score( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + texts_1: list[str], + texts_2: list[str], + request: Union[RerankRequest, ScoreRequest], + request_id=str, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[Union[LoRARequest, None]] = None, + prompt_adapter_request: Optional[Union[PromptAdapterRequest, + None]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> list[PoolingRequestOutput]: + + input_texts = texts_1 + texts_2 + + engine_prompts: list[TokensPrompt] = [] + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + + tokenization_kwargs = tokenization_kwargs or {} + tokenized_prompts = await asyncio.gather( + *(tokenize_async(t, **tokenization_kwargs) for t in input_texts)) + + for tok_result, input_text in zip(tokenized_prompts, input_texts): + + text_token_prompt = \ + self._validate_input( + request, + tok_result["input_ids"], + input_text) + + engine_prompts.append( + TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"])) + + # Schedule the request and get the result generator. 
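+        # texts_1 and texts_2 were embedded together as one batch above; once
+        # all encode() calls finish, the two halves are split back out and
+        # scored pairwise via cosine similarity (_cosine_similarity).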
+ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + input_texts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + generators.append( + self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + )) + + result_generator = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: list[PoolingRequestOutput] = [] + + embeddings: list[Optional[PoolingRequestOutput]] =\ + [None] * len(engine_prompts) + + async for i, res in result_generator: + embeddings[i] = res + + emb_texts_1: list[PoolingRequestOutput] = [] + emb_texts_2: list[PoolingRequestOutput] = [] + + for i in range(0, len(texts_1)): + assert (emb := embeddings[i]) is not None + emb_texts_1.append(emb) + + for i in range(len(texts_1), len(embeddings)): + assert (emb := embeddings[i]) is not None + emb_texts_2.append(emb) + + if len(emb_texts_1) == 1: + emb_texts_1 = emb_texts_1 * len(emb_texts_2) + + final_res_batch = _cosine_similarity(tokenizer=tokenizer, + embed_1=emb_texts_1, + embed_2=emb_texts_2) + + return final_res_batch + + async def _cross_encoding_score( + self, + tokenizer: Union[AnyTokenizer], + texts_1: list[str], + texts_2: list[str], + request: Union[RerankRequest, ScoreRequest], + request_id=str, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[Union[LoRARequest, None]] = None, + prompt_adapter_request: Optional[Union[PromptAdapterRequest, + None]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> list[PoolingRequestOutput]: + + request_prompts: list[str] = [] + engine_prompts: list[TokensPrompt] = [] + + if len(texts_1) == 1: + texts_1 = texts_1 * len(texts_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)] + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + + tokenization_kwargs = tokenization_kwargs or {} + tokenized_prompts = await asyncio.gather( + *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs) + for t1, t2 in input_pairs)) + + for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): + + request_prompt = f"{t1}{tokenizer.sep_token}{t2}" + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + # Schedule the request and get the result generator. 
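+        # Unlike the embedding path, each (text_1, text_2) pair was tokenized
+        # jointly above (separated by the tokenizer's sep_token), so a single
+        # encode() call per pair yields the cross-encoder relevance score.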
+ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + + result_generator = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: list[ + Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) + + async for i, res in result_generator: + final_res_batch[i] = res + + return [out for out in final_res_batch if out is not None] + + async def _run_scoring( + self, + texts_1: Union[str, list[str]], + texts_2: Union[str, list[str]], + request: Union[ScoreRequest, RerankRequest], + request_id: str, + raw_request: Optional[Request] = None, + truncate_prompt_tokens: Optional[int] = None, + ) -> list[PoolingRequestOutput]: + + tokenization_kwargs: dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for scoring models") + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if truncate_prompt_tokens is not None and \ + truncate_prompt_tokens > self.max_model_len: + raise ValueError( + f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " + f"is greater than max_model_len ({self.max_model_len})." 
+ f" Please, select a smaller truncation size.") + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(texts_1, str): + texts_1 = [texts_1] + if isinstance(texts_2, str): + texts_2 = [texts_2] + + _validate_score_input_lens(texts_1, texts_2) + + if self.model_config.is_cross_encoder: + return await self._cross_encoding_score( + tokenizer=tokenizer, + texts_1=texts_1, + texts_2=texts_2, + request=request, + request_id=request_id, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers) + + else: + return await self._embedding_score( + tokenizer=tokenizer, + texts_1=texts_1, + texts_2=texts_2, + request=request, + request_id=request_id, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers) + + async def create_score( + self, + request: ScoreRequest, + raw_request: Optional[Request] = None, + ) -> Union[ScoreResponse, ErrorResponse]: + """ + Score API similar to Sentence Transformers cross encoder + + See https://sbert.net/docs/package_reference/cross_encoder + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"score-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + try: + final_res_batch = await self._run_scoring( + request.text_1, + request.text_2, + request, + request_id, + raw_request, + request.truncate_prompt_tokens, + ) + + return self.request_output_to_score_response( + final_res_batch, + request_id, + created_time, + self._get_model_name(request.model), + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def do_rerank( + self, + request: RerankRequest, + raw_request: Optional[Request] = None + ) -> Union[RerankResponse, ErrorResponse]: + """ + Rerank API based on JinaAI's rerank API; implements the same + API interface. Designed for compatibility with off-the-shelf + tooling, since this is a common standard for reranking APIs + + See example client implementations at + https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py + numerous clients use this standard. 
+ """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"rerank-{self._base_request_id(raw_request)}" + documents = request.documents + top_n = request.top_n if request.top_n > 0 else len(documents) + + try: + final_res_batch = await self._run_scoring( + request.query, + documents, + request, + request_id, + raw_request, + request.truncate_prompt_tokens, + ) + return self.request_output_to_rerank_response( + final_res_batch, + request_id, + self._get_model_name(request.model), + documents, + top_n, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + def request_output_to_score_response( + self, + final_res_batch: list[PoolingRequestOutput], + request_id: str, + created_time: int, + model_name: str, + ) -> ScoreResponse: + items: list[ScoreResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + classify_res = ScoringRequestOutput.from_base(final_res) + + item = ScoreResponseData( + index=idx, + score=classify_res.outputs.score, + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ScoreResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) + + def request_output_to_rerank_response( + self, final_res_batch: list[PoolingRequestOutput], request_id: str, + model_name: str, documents: list[str], + top_n: int) -> RerankResponse: + """ + Convert the output of do_rank to a RerankResponse + """ + results: list[RerankResult] = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + classify_res = ScoringRequestOutput.from_base(final_res) + + result = RerankResult( + index=idx, + document=RerankDocument(text=documents[idx]), + relevance_score=classify_res.outputs.score, + ) + results.append(result) + prompt_token_ids = final_res.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + + # sort by relevance, then return the top n if set + results.sort(key=lambda x: x.relevance_score, reverse=True) + if top_n < len(documents): + results = results[:top_n] + + return RerankResponse( + id=request_id, + model=model_name, + results=results, + usage=RerankUsage(total_tokens=num_prompt_tokens)) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py new file mode 100644 index 0000000..c642fc5 --- /dev/null +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Final, Optional, Union + +import jinja2 +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeChatRequest, + TokenizeRequest, + TokenizeResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger + 
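+# NOTE (illustrative, not part of the upstream code): this module backs the
+# tokenize/detokenize endpoints. Assuming the API server exposes them at
+# /tokenize and /detokenize (and using the served model name from the launch
+# script as a placeholder), a request looks roughly like:
+#
+#   import requests
+#   requests.post("http://localhost:8000/tokenize",
+#                 json={"model": "llm", "prompt": "Hello, world!",
+#                       "add_special_tokens": True})
+#   requests.post("http://localhost:8000/detokenize",
+#                 json={"model": "llm", "tokens": [1, 2, 3]})
+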
+logger = init_logger(__name__) + + +class OpenAIServingTokenization(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_tokenize( + self, + request: TokenizeRequest, + raw_request: Request, + ) -> Union[TokenizeResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"tokn-{self._base_request_id(raw_request)}" + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if isinstance(request, TokenizeChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + chat_template_kwargs=request.chat_template_kwargs, + add_special_tokens=request.add_special_tokens, + ) + else: + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + add_special_tokens=request.add_special_tokens, + ) + except (ValueError, TypeError, jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + input_ids: list[int] = [] + for i, engine_prompt in enumerate(engine_prompts): + self._log_inputs(request_id, + request_prompts[i], + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + # Silently ignore prompt adapter since it does not affect + # tokenization (Unlike in Embeddings API where an error is raised) + + input_ids.extend(engine_prompt["prompt_token_ids"]) + + return TokenizeResponse(tokens=input_ids, + count=len(input_ids), + max_model_len=self.max_model_len) + + async def create_detokenize( + self, + request: DetokenizeRequest, + raw_request: Request, + ) -> Union[DetokenizeResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"tokn-{self._base_request_id(raw_request)}" + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + self._log_inputs(request_id, + request.tokens, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + # Silently ignore prompt adapter since it does not affect tokenization + # (Unlike in Embeddings API where an error is raised) + + prompt_input = await self._tokenize_prompt_input_async( + request, + tokenizer, + request.tokens, + ) + input_text = prompt_input["prompt"] + + return DetokenizeResponse(prompt=input_text) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 
0000000..13565d0 --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,421 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import io +import time +from collections.abc import AsyncGenerator +from math import ceil +from typing import Final, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest, + TranscriptionResponse, TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, UsageInfo) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +logger = init_logger(__name__) + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages +# TODO these configs should live somewhere with the model so we can support +# additional ones + +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", # codespell:ignore + "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. 
+# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 + + +class OpenAIServingTranscription(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + processor = cached_get_processor(model_config.model) + self.max_audio_clip_s = processor.feature_extractor.chunk_length + self.model_sr = processor.feature_extractor.sampling_rate + self.hop_length = processor.feature_extractor.hop_length + + if self.default_sampling_params: + logger.info( + "Overwriting default completion sampling param with: %s", + self.default_sampling_params) + + async def _preprocess_transcription( + self, + request: TranscriptionRequest, + audio_data: bytes, + ) -> tuple[PromptType, float]: + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang_token = f"<|{request.language}|>" if request.language else "<|en|>" + if request.language: + if request.language in ISO639_1_SUPPORTED_LANGS: + pass + elif request.language in ISO639_1_OTHER_LANGS: + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", request.language) + else: + raise ValueError( + f"Unsupported language: {request.language}." + "Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f"or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") + + with io.BytesIO(audio_data) as bytes_: + y, sr = librosa.load(bytes_) + + duration = librosa.get_duration(y=y, sr=sr) + if duration > self.max_audio_clip_s: + raise ValueError( + f"Maximum clip duration ({self.max_audio_clip_s}s) " + "exceeded.") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (y, sr), + }, + }, + "decoder_prompt": + f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" + } + return cast(PromptType, prompt), duration + + # TODO (varun) : Make verbose response work ! + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, AsyncGenerator[str, None], + ErrorResponse]: + """Transcription API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI transcription API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + + request_id = f"trsc-{self._base_request_id(raw_request)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for Transcription.") + if prompt_adapter_request: + return self.create_error_response( + "Currently do not support PromptAdapter for Transcription." + ) + + prompt, duration_s = await self._preprocess_transcription( + request=request, + audio_data=audio_data, + ) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None + try: + # TODO(rob): subtract len of tokenized prompt. + default_max_tokens = self.model_config.max_model_len + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + self._log_inputs( + request_id, + prompt['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + result_generator = self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + if request.stream: + return self.transcription_stream_generator(request, + result_generator, + request_id, + request_metadata, + duration_s) + # Non-streaming response. + try: + assert result_generator is not None + async for op in result_generator: + result = op + return TranscriptionResponse(text=result.outputs[0].text) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def transcription_stream_generator( + self, request: TranscriptionRequest, + result_generator: AsyncGenerator[RequestOutput, None], + request_id: str, request_metadata: RequestResponseMetadata, + audio_duration_s: float) -> AsyncGenerator[str, None]: + created_time = int(time.time()) + model_name = request.model + chunk_object_type: Final = "transcription.chunk" + + completion_tokens = 0 + num_prompt_tokens = 0 + + include_usage = request.stream_include_usage \ + if request.stream_include_usage else False + include_continuous_usage = request.stream_continuous_usage_stats\ + if include_usage and request.stream_continuous_usage_stats\ + else False + + try: + async for res in result_generator: + # On first result. + if res.prompt_token_ids is not None: + # Do not account the 4-tokens `<|startoftranscript|>..` + # Could be negative when language token is not specified. + num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0) + # NOTE(NickLucche) user can't pass encoder prompts directly + # at least not to Whisper. One indicator of the encoder + # amount of processing is the log-mel spectogram length. 
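+                    # Number of mel frames ~= audio_duration_s * sampling_rate
+                    # / hop_length, which is what the ceil() below adds to the
+                    # prompt-token count as a proxy for encoder-side work.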
+ num_prompt_tokens += ceil(audio_duration_s * + self.model_sr / self.hop_length) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + + # Just one output (n=1) supported. + assert len(res.outputs) == 1 + output = res.outputs[0] + + delta_message = DeltaMessage(content=output.text) + completion_tokens += len(output.token_ids) + + if output.finish_reason is None: + # Still generating, send delta update. + choice_data = TranscriptionResponseStreamChoice( + delta=delta_message) + else: + # Model is finished generating. + choice_data = TranscriptionResponseStreamChoice( + delta=delta_message, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + + chunk = TranscriptionStreamResponse(id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Once the final token is handled, if stream_options.include_usage + # is sent, send the usage. + if include_usage: + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + + final_usage_chunk = TranscriptionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in chat completion stream generator.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000..b81dc4e --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .abstract_tool_parser import ToolParser, ToolParserManager +from .granite_20b_fc_tool_parser import Granite20bFCToolParser +from .granite_tool_parser import GraniteToolParser +from .hermes_tool_parser import Hermes2ProToolParser +from .internlm2_tool_parser import Internlm2ToolParser +from .jamba_tool_parser import JambaToolParser +from .llama_tool_parser import Llama3JsonToolParser +from .mistral_tool_parser import MistralToolParser +from .phi4mini_tool_parser import Phi4MiniJsonToolParser +from .pythonic_tool_parser import PythonicToolParser + +__all__ = [ + "ToolParser", "ToolParserManager", "Granite20bFCToolParser", + "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", + "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", + "PythonicToolParser", "Phi4MiniJsonToolParser" +] diff --git 
a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000..931d5aa --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from collections.abc import Sequence +from functools import cached_property +from typing import Callable, Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import import_from_path, is_list_of + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.prev_tool_call_arr: list[dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [] + + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """ + Static method that used to adjust the request parameters. + """ + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") + + +class ToolParserManager: + tool_parsers: dict[str, type] = {} + + @classmethod + def get_tool_parser(cls, name) -> type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. 
+ """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + + @classmethod + def _register_module(cls, + module: type, + module_name: Optional[Union[str, list[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ToolParser): + raise TypeError( + f'module must be subclass of ToolParser, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f'{name} is already registered ' + f'at {existed_module.__module__}') + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, list[str]]] = None, + force: bool = True, + module: Union[type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined tool parser by the path of the tool parser define + file. 
+ """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + + try: + import_from_path(module_name, plugin_path) + except Exception: + logger.exception("Failed to load module '%s' from %s.", + module_name, plugin_path) + return diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py new file mode 100644 index 0000000..76da63c --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from collections.abc import Sequence +from json import JSONDecoder +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite-20b-fc") +class Granite20bFCToolParser(ToolParser): + """ + Tool call parser for the granite-20b-functioncalling model intended + for use with the examples/tool_chat_template_granite20b_fc.jinja + template. + + Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.bot_token = "" + self.tool_start_token = self.bot_token + self.tool_call_regex = re.compile(r"\s*") + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + if self.tool_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + dec = JSONDecoder() + try: + matches = list(self.tool_call_regex.finditer(model_output)) + logger.debug("Found %d tool call matches", len(matches)) + + raw_function_calls = [] + + for i, match in enumerate(matches): + # position after the tag + start_of_json = match.end() + # end_index == the start of the next function call + # (if exists) + next_function_call_start = (matches[i + 1].start() if i + + 1 < len(matches) else None) + + raw_function_calls.append( + dec.raw_decode( + model_output[start_of_json:next_function_call_start]) + [0]) + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]), + ), + ) for function_call in raw_function_calls + ] + + content = model_output[:model_output.find(self.bot_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + 
previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if len(current_text) < len( + self.bot_token) and self.bot_token.startswith(current_text): + return None + + if not current_text.startswith(self.bot_token): + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + start_idx = len(self.bot_token) + start_idx = consume_space(start_idx, current_text) + + while start_idx < len(current_text): + (obj, + end_idx) = partial_json_loads(current_text[start_idx:], + flags) + is_complete.append( + is_complete_json(current_text[start_idx:start_idx + + end_idx])) + start_idx += end_idx + start_idx = consume_space(start_idx, current_text) + start_idx += len(self.bot_token) + start_idx = consume_space(start_idx, current_text) + tool_call_arr.append(obj) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + delta = None + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py new file mode 100644 index 0000000..91afc88 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite") +class GraniteToolParser(ToolParser): + """ + Tool call parser for the granite 3.0 models. Intended + for use with the examples/tool_chat_template_granite.jinja + template. 
+ + Used when --enable-auto-tool-choice --tool-call-parser granite + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + # for granite 3.0, the token `<|tool_call|>` + self.bot_token = "<|tool_call|>" + # for granite 3.1, the string `` + self.bot_string = "" + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + stripped = model_output.strip()\ + .removeprefix(self.bot_token)\ + .removeprefix(self.bot_string)\ + .lstrip() + if not stripped or stripped[0] != '[': + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + raw_function_calls = json.loads(stripped) + if not isinstance(raw_function_calls, list): + raise Exception( + f"Expected dict or list, got {type(raw_function_calls)}") + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]), + ), + ) for function_call in raw_function_calls + ] + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + start_idx = consume_space(0, current_text) + if current_text[start_idx:].startswith(self.bot_token): + start_idx = consume_space(start_idx + len(self.bot_token), + current_text) + if current_text[start_idx:].startswith(self.bot_string): + start_idx = consume_space(start_idx + len(self.bot_string), + current_text) + if not current_text or start_idx >= len(current_text)\ + or current_text[start_idx] != '[': + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = None + is_complete = None + try: + tool_calls, end_idx = partial_json_loads( + current_text[start_idx:], flags) + if type(tool_calls) is list: + tool_call_arr = tool_calls + else: + return DeltaMessage(content=delta_text) + + is_complete = [True] * len(tool_calls) + if not is_complete_json( + current_text[start_idx:start_idx + end_idx]): + is_complete[-1] = False + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case -- if no tokens have been streamed for the tool, e.g. 
+ # only the array brackets, stream nothing + if not tool_call_arr: + return None + + # select as the current tool call the one we're on the state at + current_tool_call: dict = tool_call_arr[self.current_tool_id] + + delta = None + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + if len(tool_call_arr) > self.current_tool_id + 1: + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000..4c39e9b --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("hermes") +class Hermes2ProToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + self.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function 
call args are JSON but as a string + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case: if tool open & close tag counts don't match, we're doing + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. 
exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if (self.prev_tool_call_arr is None + or len(self.prev_tool_call_arr) == 0): + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = diff.encode('utf-8').decode( + 'unicode_escape') if diff is str else diff + if ('"}' not in delta_text): + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + except json.decoder.JSONDecodeError: + logger.debug("unable to parse JSON") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. 
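# Illustrative aside (standalone sketch, not part of the patch): the two
# kinds of deltas this method emits. The first delta for a tool call carries
# the index, id, type and function name; every later delta carries only the
# next fragment of the arguments JSON. The tool name and fragment below are
# invented for the example.
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall)
from vllm.utils import random_uuid

name_first = DeltaMessage(tool_calls=[
    DeltaToolCall(index=0,
                  type="function",
                  id=f"chatcmpl-tool-{random_uuid()}",
                  function=DeltaFunctionCall(
                      name="get_weather").model_dump(exclude_none=True))
])
args_fragment = DeltaMessage(tool_calls=[
    DeltaToolCall(index=0,
                  function=DeltaFunctionCall(
                      arguments='{"city": "Par').model_dump(
                          exclude_none=True))
])
# Concatenating the argument fragments for index 0 reproduces the full
# arguments JSON once the call is complete.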
+ if not self.current_tool_name_sent: + if (current_tool_call is None): + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + else: + return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + + # get the location where previous args differ from current + if (delta_text not in cur_arguments_json[:-2]): + return None + args_delta_start_loc = cur_arguments_json[:-2]. \ + rindex(delta_text) + \ + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + + # last case -- we have an update to existing arguments. 
+ elif cur_arguments and prev_arguments: + if isinstance(delta_text, str) and len(delta_text.rstrip( + )) >= 1 and delta_text.rstrip()[-1] == '}': + delta_text = delta_text.rstrip()[:-1] + + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_text).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += delta_text + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py new file mode 100644 index 0000000..57d7c77 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["internlm"]) +class Internlm2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because internlm use the special + # tokens to indicated the start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def get_argments(self, obj): + if "parameters" in obj: + return obj.get("parameters") + elif "arguments" in obj: + return obj.get("arguments") + return None + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if '<|action_start|>' not in current_text: + self.position = len(current_text) + return DeltaMessage(content=delta_text) + # if the tool call is sended, return a empty delta message + # to make sure the finish_reason will be send correctly. 
+ if self.current_tool_id > 0: + return DeltaMessage(content='') + + last_pos = self.position + if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + return None + + new_delta = current_text[last_pos:] + text, action = new_delta.split('<|action_start|><|plugin|>') + + if len(text) > 0: + self.position = self.position + len(text) + return DeltaMessage(content=text) + + action = action.strip() + action = action.split('<|action_end|>'.strip())[0] + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_arr = action + + # tool calls are generated in an object in inernlm2 + # it's not support parallel tool calls + try: + tool_call_arr: dict = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = tool_call_arr.get("name") + if function_name: + self.current_tool_id = self.current_tool_id + 1 + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + self.streamed_args_for_tool.append("") + else: + delta = None + # now we know we're on the same tool call and we're streaming + # arguments + else: + prev_arguments = self.get_argments( + self.prev_tool_call_arr[self.current_tool_id]) + cur_arguments = self.get_argments(tool_call_arr) + + # not arguments generated + if not cur_arguments and not prev_arguments: + delta = None + # will never happen + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + # first time to get parameters + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(delta_text) + + len(delta_text)] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + # both prev and cur parameters, send the increase parameters + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # check to see if the name is defined and has been sent. 
if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + tool_call_arr["arguments"] = self.get_argments(tool_call_arr) + self.prev_tool_call_arr = [tool_call_arr] + return delta + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + tools = request.tools + if '<|action_start|><|plugin|>' in text: + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', + {}))) + + if not tools or name not in [t.function.name for t in tools]: + ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + tool_calls = [ + ToolCall( + function=FunctionCall(name=name, arguments=parameters)) + ] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=text if len(text) > 0 else None) + + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py new file mode 100644 index 0000000..8df106b --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizers import MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("jamba") +class JambaToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + raise ValueError( + "Detected a MistralTokenizer tokenizer when using a Jamba model" + ) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "" + self.tool_calls_end_token: str = "" + + self.tool_calls_regex = re.compile( + rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", + re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + if 
(self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "Jamba Tool parser could not locate tool calls start/end " + "tokens in the tokenizer!") + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because jamba use the special + # tokens to indicate the start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # use a regex to find the tool call between the tags + function_calls = self.tool_calls_regex.findall(model_output)[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = json.loads(function_calls) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if + (len(content) > 0 and content != " ") else None) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.tool_calls_start_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the start of tool calls token which means + # the start of tool calling + if (self.tool_calls_start_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion and don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. 
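# Illustrative aside (standalone sketch, not part of the patch): what the
# Allow bit mask below changes. The truncated chunk is an invented example of
# a half-generated tool call.
import partial_json_parser
from partial_json_parser.core.options import Allow

chunk = '{"name": "get_weather", "arguments": {"city": "Par'

# Allow.ALL may complete the dangling string, so the unfinished value can
# surface in the parsed result.
print(partial_json_parser.loads(chunk, Allow.ALL))

# Masking out Allow.STR refuses to complete a partial string, so a
# half-generated function name (or argument string) is held back until more
# tokens arrive.
print(partial_json_parser.loads(chunk, Allow.ALL & ~Allow.STR))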
+ flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # Extract the tool calls between the special tool call tokens + parsable_arr = current_text.split( + self.tool_calls_start_token)[-1].split( + self.tool_calls_end_token)[0] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: list[dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[:cur_arguments_json. 
+ index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py new file mode 100644 index 0000000..20c3238 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from collections.abc import Sequence +from json import JSONDecoder +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("llama3_json") +class Llama3JsonToolParser(ToolParser): + """ + Tool call parser for Llama 3.1 models intended for use with the + examples/tool_chat_template_llama.jinja template. 
+ + Used when --enable-auto-tool-choice --tool-call-parser llama3_json + are all set + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "<|python_tag|>" + self.bot_token_id = tokenizer.encode(self.bot_token, + add_special_tokens=False)[0] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + # case -- if a tool call token is not present, return a text response + if not (model_output.startswith(self.bot_token) + or model_output.startswith('{')): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + # load the JSON, and then use it to build the Function and + # Tool Call + dec = JSONDecoder() + function_call_arr = [] + + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if model_output.startswith( + self.bot_token) else 0 + while start_idx < len(model_output): + (obj, end_idx) = dec.raw_decode(model_output[start_idx:]) + start_idx += end_idx + len('; ') + function_call_arr.append(obj) + + tool_calls: list[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"] \ + if "arguments" in raw_function_call \ + else raw_function_call["parameters"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + ret = ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=None) + return ret + + except Exception: + logger.exception("Error in extracting tool call from response.") + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not (current_text.startswith(self.bot_token) + or current_text.startswith('{')): + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. 
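# Illustrative aside (standalone sketch, not part of the patch): a minimal
# version of the scan loop used below for Llama's "{...}; {...}" output. The
# sample text is invented. partial_json_loads returns the parsed object plus
# the number of characters it consumed; is_complete_json reports whether that
# span was already well-formed JSON rather than auto-completed.
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.tool_parsers.utils import (is_complete_json,
                                                        partial_json_loads)

text = '{"name": "a", "arguments": {}}; {"name": "b", "arguments": {"x": 1'
start, calls, complete = 0, [], []
while start < len(text):
    obj, end = partial_json_loads(text[start:], Allow.ALL)
    complete.append(is_complete_json(text[start:start + end]))
    calls.append(obj)
    start += end + len('; ')

print(calls)     # both objects, the second one auto-completed
print(complete)  # only the first span was complete JSON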
+ flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if current_text.startswith( + self.bot_token) else 0 + while start_idx < len(current_text): + (obj, + end_idx) = partial_json_loads(current_text[start_idx:], + flags) + is_complete.append( + is_complete_json(current_text[start_idx:start_idx + + end_idx])) + start_idx += end_idx + len('; ') + # depending on the prompt Llama can use + # either arguments or parameters + if "parameters" in obj: + assert "arguments" not in obj, \ + "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + delta = None + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000..0661445 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from collections.abc import Sequence +from random import choices +from string import ascii_letters, digits +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow +from pydantic import Field + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + +ALPHANUMERIC = ascii_letters + digits + + +class MistralToolCall(ToolCall): + id: str = Field( + default_factory=lambda: MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a length of 9. 
+ # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + + +@ToolParserManager.register_module("mistral") +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not isinstance(self.model_tokenizer, MistralTokenizer): + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.vocab.get(self.bot_token) + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + if self.bot_token_id is None: + raise RuntimeError( + "Mistral Tool Parser could not locate the tool call token in " + "the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + # first remove the BOT token + tool_content = model_output.replace(self.bot_token, "").strip() + + try: + + # we first try to directly load the json as parsing very nested + # jsons is difficult + try: + function_call_arr = json.loads(tool_content) + except json.JSONDecodeError: + # use a regex to find the part corresponding to the tool call. + # NOTE: This use case should not happen if the model is trained + # correctly. 
It's a easy possible fix so it's included, but + # can be brittle for very complex / highly nested tool calls + raw_tool_call = self.tool_call_regex.findall(tool_content)[0] + function_call_arr = json.loads(raw_tool_call) + + # Tool Call + tool_calls: list[MistralToolCall] = [ + MistralToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"], + ensure_ascii=False))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception: + logger.exception("Error in extracting tool call from response.") + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=tool_content) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + parsable_arr = current_text.split(self.bot_token)[-1] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: list[dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. 
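# Illustrative aside (standalone sketch, not part of the patch): the flush
# done before advancing to the next tool call. The literals are invented; the
# idea is to re-serialize the finished arguments and emit only the suffix the
# client has not received yet.
import json

already_streamed = '{"city": "Paris"'          # argument fragments sent so far
finished_arguments = {"city": "Paris", "unit": "celsius"}

full_json = json.dumps(finished_arguments)
remaining = full_json[len(already_streamed):]
print(remaining)  # ', "unit": "celsius"}' is sent as one last arguments delta
                  # before the cursor moves to the next tool in the array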
+ if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff, ensure_ascii=False).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + if ('"}' in new_text): + new_text = new_text[:new_text.rindex('"}')] + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False)[:-2] + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + if (new_text not in cur_arguments_json): + return None + arguments_delta = cur_arguments_json[:cur_arguments_json. + rindex(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. 
if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py new file mode 100644 index 0000000..668776a --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from collections.abc import Sequence +from typing import Any, Optional + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("phi4_mini_json") +class Phi4MiniJsonToolParser(ToolParser): + """ + Tool call parser for phi-4-mini models intended for use with the + examples/tool_chat_template_llama.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json + are all set + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None: + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: list[dict[str, Any]] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token: str = "functools" + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + logger.debug("Model output: %s", model_output) + + pattern = r'functools\[(.*?)\]' + matches = re.search(pattern, model_output, re.DOTALL) + + if not matches: + logger.debug("No function calls found") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + function_call_arr: list[dict[str, Any]] = [] + try: + json_content = '[' + matches.group(1) + ']' + + function_call_arr = json.loads(json_content) + logger.debug("Successfully extracted %d function calls", + len(function_call_arr)) + except json.JSONDecodeError as e: + logger.error( + "Failed to parse function calls from model output: %s. 
" + "Error: %s", model_output, str(e)) + + tool_calls: list[ToolCall] = [ + ToolCall( + id=f"chatcmpl-tool-{random_uuid()}", + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps( + raw_function_call["arguments"] if "arguments" in + raw_function_call else + raw_function_call["parameters"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + ret = ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=None) + return ret + + except Exception: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Optional[DeltaMessage]: + + return None diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py new file mode 100644 index 0000000..1b9317f --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 + +import ast +import json +import re +from collections.abc import Sequence +from typing import Any, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("pythonic") +class PythonicToolParser(ToolParser): + """ + Tool call parser for models that produce tool calls in a pythonic style, + such as Llama 3.2 models. + + Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set + """ + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. 
+ """ + + if not (self.TOOL_CALL_REGEX.match(model_output)): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("["): + return DeltaMessage(content=delta_text) + + try: + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self.streamed_args_for_tool[ + index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. 
+ return DeltaMessage(content='') + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall(type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments))) + + +def _make_valid_python(text: str) -> Union[tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. 
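+ # e.g. a partial call such as "[get_weather(city=" (name is purely
+ # illustrative) cannot be completed into valid Python yet, so the
+ # streaming caller is told to wait for more text.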
+ return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall( + id="", index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000..7997629 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from json import JSONDecodeError, JSONDecoder +from typing import Any + +import partial_json_parser +from partial_json_parser.core.options import Allow + + +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. 
find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) + + return diff + + +def find_all_indices(string: str, substring: str) -> list[int]: + """ + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices + + +# partial_json_parser doesn't support extra data and +# JSONDecorder.raw_decode doesn't support partial JSON +def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +def consume_space(i: int, s: str) -> int: + while i < len(s) and s[i].isspace(): + i += 1 + return i diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py new file mode 100644 index 0000000..53411a2 --- /dev/null +++ b/vllm/entrypoints/score_utils.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Union + +from torch.nn import CosineSimilarity + +from vllm.outputs import PoolingRequestOutput +from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer, + PreTrainedTokenizerFast) + + +def _cosine_similarity( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + embed_1: list[PoolingRequestOutput], + embed_2: list[PoolingRequestOutput], +) -> list[PoolingRequestOutput]: + + scorer = CosineSimilarity(0) + scores: Union[list[PoolingRequestOutput]] = [] + + for emb_1, emb_2 in zip(embed_1, embed_2): + pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) + + padding = [] + if (pad_token_id := getattr(tokenizer, "pad_token_id", + None)) is not None: + padding = [pad_token_id] + + tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids + + scores.append( + PoolingRequestOutput( + request_id=f"{emb_1.request_id}_{emb_2.request_id}", + 
outputs=pair_score, + prompt_token_ids=tokens, + finished=True)) + + return scores + + +def _validate_score_input_lens( + texts_1: Union[list[str], list[dict]], + texts_2: Union[list[str], list[dict]], +): + if len(texts_1) > 1 and len(texts_1) != len(texts_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(texts_1) == 0: + raise ValueError("At least one text element must be given") + if len(texts_2) == 0: + raise ValueError("At least one text_pair element must be given") diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py new file mode 100644 index 0000000..dba916b --- /dev/null +++ b/vllm/entrypoints/ssl.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from ssl import SSLContext +from typing import Callable, Optional + +from watchfiles import Change, awatch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SSLCertRefresher: + """A class that monitors SSL certificate files and + reloads them when they change. + """ + + def __init__(self, + ssl_context: SSLContext, + key_path: Optional[str] = None, + cert_path: Optional[str] = None, + ca_path: Optional[str] = None) -> None: + self.ssl = ssl_context + self.key_path = key_path + self.cert_path = cert_path + self.ca_path = ca_path + + # Setup certification chain watcher + def update_ssl_cert_chain(change: Change, file_path: str) -> None: + logger.info("Reloading SSL certificate chain") + assert self.key_path and self.cert_path + self.ssl.load_cert_chain(self.cert_path, self.key_path) + + self.watch_ssl_cert_task = None + if self.key_path and self.cert_path: + self.watch_ssl_cert_task = asyncio.create_task( + self._watch_files([self.key_path, self.cert_path], + update_ssl_cert_chain)) + + # Setup CA files watcher + def update_ssl_ca(change: Change, file_path: str) -> None: + logger.info("Reloading SSL CA certificates") + assert self.ca_path + self.ssl.load_verify_locations(self.ca_path) + + self.watch_ssl_ca_task = None + if self.ca_path: + self.watch_ssl_ca_task = asyncio.create_task( + self._watch_files([self.ca_path], update_ssl_ca)) + + async def _watch_files(self, paths, fun: Callable[[Change, str], + None]) -> None: + """Watch multiple file paths asynchronously.""" + logger.info("SSLCertRefresher monitors files: %s", paths) + async for changes in awatch(*paths): + try: + for change, file_path in changes: + logger.info("File change detected: %s - %s", change.name, + file_path) + fun(change, file_path) + except Exception as e: + logger.error( + "SSLCertRefresher failed taking action on file change. 
" + "Error: %s", e) + + def stop(self) -> None: + """Stop watching files.""" + if self.watch_ssl_cert_task: + self.watch_ssl_cert_task.cancel() + self.watch_ssl_cert_task = None + if self.watch_ssl_ca_task: + self.watch_ssl_ca_task.cancel() + self.watch_ssl_ca_task = None diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py new file mode 100644 index 0000000..b88c2b3 --- /dev/null +++ b/vllm/entrypoints/utils.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import functools +import os + +from fastapi import Request +from fastapi.responses import JSONResponse, StreamingResponse +from starlette.background import BackgroundTask, BackgroundTasks + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +async def listen_for_disconnect(request: Request) -> None: + """Returns if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + """Decorator that allows a route handler to be cancelled by client + disconnections. + + This does _not_ use request.is_disconnected, which does not work with + middleware. Instead this follows the pattern from + starlette.StreamingResponse, which simultaneously awaits on two tasks- one + to wait for an http disconnect message, and the other to do the work that we + want done. When the first task finishes, the other is cancelled. + + A core assumption of this method is that the body of the request has already + been read. This is a safe assumption to make for fastapi handlers that have + already parsed the body of the request into a pydantic model for us. + This decorator is unsafe to use elsewhere, as it will consume and throw away + all incoming messages for the request while it looks for a disconnect + message. + + In the case where a `StreamingResponse` is returned by the handler, this + wrapper will stop listening for disconnects and instead the response object + will start listening for disconnects. + """ + + # Functools.wraps is required for this wrapper to appear to fastapi as a + # normal route handler, with the correct request type hinting. 
+ @functools.wraps(handler_func) + async def wrapper(*args, **kwargs): + + # The request is either the second positional arg or `raw_request` + request = args[1] if len(args) > 1 else kwargs["raw_request"] + + handler_task = asyncio.create_task(handler_func(*args, **kwargs)) + cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + + done, pending = await asyncio.wait([handler_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + + if handler_task in done: + return handler_task.result() + return None + + return wrapper + + +def decrement_server_load(request: Request): + request.app.state.server_load_metrics -= 1 + + +def load_aware_call(func): + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + raw_request = kwargs.get("raw_request", + args[1] if len(args) > 1 else None) + + if raw_request is None: + raise ValueError( + "raw_request required when server load tracking is enabled") + + if not raw_request.app.state.enable_server_load_tracking: + return await func(*args, **kwargs) + + raw_request.app.state.server_load_metrics += 1 + try: + response = await func(*args, **kwargs) + except Exception: + raw_request.app.state.server_load_metrics -= 1 + raise + + if isinstance(response, (JSONResponse, StreamingResponse)): + if response.background is None: + response.background = BackgroundTask(decrement_server_load, + raw_request) + elif isinstance(response.background, BackgroundTasks): + response.background.add_task(decrement_server_load, + raw_request) + elif isinstance(response.background, BackgroundTask): + # Convert the single BackgroundTask to BackgroundTasks + # and chain the decrement_server_load task to it + tasks = BackgroundTasks() + tasks.add_task(response.background.func, + *response.background.args, + **response.background.kwargs) + tasks.add_task(decrement_server_load, raw_request) + response.background = tasks + else: + raw_request.app.state.server_load_metrics -= 1 + + return response + + return wrapper + + +def cli_env_setup(): + # The safest multiprocessing method is `spawn`, as the default `fork` method + # is not compatible with some accelerators. The default method will be + # changing in future versions of Python, so we should use it explicitly when + # possible. + # + # We only set it here in the CLI entrypoint, because changing to `spawn` + # could break some existing code using vLLM as a library. `spawn` will cause + # unexpected behavior if the code is not protected by + # `if __name__ == "__main__":`. 
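+ # For example, a user script that constructs vllm.LLM(...) at module
+ # scope would re-run that top-level code in every spawned worker unless
+ # it is wrapped in the `if __name__ == "__main__":` guard.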
+ # + # References: + # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors + # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders + if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ: + logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/vllm/env_override.py b/vllm/env_override.py new file mode 100644 index 0000000..a351fc7 --- /dev/null +++ b/vllm/env_override.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import torch + +# set some common config/environment variables that should be set +# for all processes created by vllm and all processes +# that interact with vllm workers. +# they are executed whenever `import vllm` is called. + +# see https://github.com/NVIDIA/nccl/issues/1234 +os.environ['NCCL_CUMEM_ENABLE'] = '0' + +# see https://github.com/vllm-project/vllm/pull/15951 +# it avoids unintentional cuda initialization from torch.cuda.is_available() +os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1' + +# see https://github.com/vllm-project/vllm/issues/10480 +os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' +# see https://github.com/vllm-project/vllm/issues/10619 +# torch._inductor.config.compile_threads = 1 diff --git a/vllm/envs.py b/vllm/envs.py new file mode 100644 index 0000000..0d6d5d6 --- /dev/null +++ b/vllm/envs.py @@ -0,0 +1,800 @@ +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import os +import sys +import tempfile +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + VLLM_HOST_IP: str = "" + VLLM_PORT: Optional[int] = None + VLLM_RPC_BASE_PATH: str = tempfile.gettempdir() + VLLM_USE_MODELSCOPE: bool = False + VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 + VLLM_NCCL_SO_PATH: Optional[str] = None + LD_LIBRARY_PATH: Optional[str] = None + VLLM_USE_TRITON_FLASH_ATTN: bool = False + VLLM_FLASH_ATTN_VERSION: Optional[int] = None + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: Optional[str] = None + VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 + VLLM_API_KEY: Optional[str] = None + S3_ACCESS_KEY_ID: Optional[str] = None + S3_SECRET_ACCESS_KEY: Optional[str] = None + S3_ENDPOINT_URL: Optional[str] = None + VLLM_MODEL_REDIRECT_PATH: Optional[str] = None + VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm") + VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") + VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" + VLLM_NO_USAGE_STATS: bool = False + VLLM_DO_NOT_TRACK: bool = False + VLLM_USAGE_SOURCE: str = "" + VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_LEVEL: str = "INFO" + VLLM_LOGGING_PREFIX: str = "" + VLLM_LOGGING_CONFIG_PATH: Optional[str] = None + VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None + VLLM_TRACE_FUNCTION: int = 0 + VLLM_ATTENTION_BACKEND: Optional[str] = None + VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None + VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False + VLLM_PP_LAYER_PARTITION: Optional[str] = None + VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_CPU_OMP_THREADS_BIND: str = "" + VLLM_CPU_MOE_PREPACK: bool = True + VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") + VLLM_XLA_CHECK_RECOMPILATION: bool = False + VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + 
VLLM_USE_RAY_SPMD_WORKER: bool = False + VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto" + VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False + VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") + VLLM_IMAGE_FETCH_TIMEOUT: int = 5 + VLLM_VIDEO_FETCH_TIMEOUT: int = 30 + VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_MM_INPUT_CACHE_GIB: int = 8 + VLLM_TARGET_DEVICE: str = "cuda" + MAX_JOBS: Optional[str] = None + NVCC_THREADS: Optional[str] = None + VLLM_USE_PRECOMPILED: bool = False + VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False + VLLM_NO_DEPRECATION_WARNING: bool = False + VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False + CMAKE_BUILD_TYPE: Optional[str] = None + VERBOSE: bool = False + VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False + VLLM_RPC_TIMEOUT: int = 10000 # ms + VLLM_PLUGINS: Optional[list[str]] = None + VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_USE_TRITON_AWQ: bool = False + VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False + VLLM_SKIP_P2P_CHECK: bool = False + VLLM_DISABLED_KERNELS: list[str] = [] + VLLM_USE_V1: bool = False + VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_LINEAR: bool = True + VLLM_ROCM_USE_AITER_MOE: bool = True + VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False + VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_FP8_PADDING: bool = True + VLLM_ROCM_MOE_PADDING: bool = True + VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_ENABLE_V1_MULTIPROCESSING: bool = True + VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 + VLLM_DISABLE_COMPILE_CACHE: bool = False + Q_SCALE_CONSTANT: int = 200 + K_SCALE_CONSTANT: int = 200 + V_SCALE_CONSTANT: int = 100 + VLLM_SERVER_DEV_MODE: bool = False + VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 + VLLM_MLA_DISABLE: bool = False + VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False + VLLM_RAY_PER_WORKER_GPUS: float = 1.0 + VLLM_RAY_BUNDLE_INDICES: str = "" + VLLM_CUDART_SO_PATH: Optional[str] = None + VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True + VLLM_DP_RANK: int = 0 + VLLM_DP_RANK_LOCAL: int = -1 + VLLM_DP_SIZE: int = 1 + VLLM_DP_MASTER_IP: str = "" + VLLM_DP_MASTER_PORT: int = 0 + VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_V0_USE_OUTLINES_CACHE: bool = False + VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False + VLLM_TPU_BUCKET_PADDING_GAP: int = 0 + VLLM_USE_DEEP_GEMM: bool = False + # optional envs we add. + VLLM_W8A8_MOE_USE_W4A8: bool = False + VLLM_W4A8_FORMAT: str = "TN" + VLLM_W4A8_VERSION: int = 2 + VLLM_MIX_QUANTIZATION_TYPE: str = "" + VLLM_MLA_CUSTOMIZE: bool = True + VLLM_USE_INT8_MLA: bool = False + + +def get_default_cache_root(): + return os.getenv( + "XDG_CACHE_HOME", + os.path.join(os.path.expanduser("~"), ".cache"), + ) + + +def get_default_config_root(): + return os.getenv( + "XDG_CONFIG_HOME", + os.path.join(os.path.expanduser("~"), ".config"), + ) + + +def maybe_convert_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + return int(value) + + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + +environment_variables: dict[str, Callable[[], Any]] = { + + # ================== Installation Time Env Vars ================== + + # Target device of vLLM, supporting [cuda (by default), + # rocm, neuron, cpu] + "VLLM_TARGET_DEVICE": + lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), + + # Maximum number of compilation jobs to run in parallel. 
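+ # (e.g. MAX_JOBS=4 caps a from-source build at four parallel compile
+ # jobs.)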
+ # By default this is the number of CPUs + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. + "NVCC_THREADS": + lambda: os.getenv("NVCC_THREADS", None), + + # If set, vllm will use precompiled binaries (*.so) + "VLLM_USE_PRECOMPILED": + lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool( + os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + + # Whether to force using nightly wheel in python build. + # This is used for testing the nightly wheel in python build. + "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": + lambda: bool(int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) + ), + + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + + # If set, vllm will print verbose logs during installation + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + + # Root directory for vLLM configuration files + # Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set + # Note that this not only affects how vllm finds its configuration files + # during runtime, but also affects how vllm installs its configuration + # files during **installation**. + "VLLM_CONFIG_ROOT": + lambda: os.path.expanduser( + os.getenv( + "VLLM_CONFIG_ROOT", + os.path.join(get_default_config_root(), "vllm"), + )), + + # ================== Runtime Env Vars ================== + + # Root directory for vLLM cache files + # Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set + "VLLM_CACHE_ROOT": + lambda: os.path.expanduser( + os.getenv( + "VLLM_CACHE_ROOT", + os.path.join(get_default_cache_root(), "vllm"), + )), + + # used in distributed environment to determine the ip address + # of the current node, when the node has multiple network interfaces. + # If you are using multi-node inference, you should set this differently + # on each node. + 'VLLM_HOST_IP': + lambda: os.getenv('VLLM_HOST_IP', ""), + + # used in distributed environment to manually set the communication port + # Note: if VLLM_PORT is set, and some code asks for multiple ports, the + # VLLM_PORT will be used as the first port, and the rest will be generated + # by incrementing the VLLM_PORT value. + # '0' is used to make mypy happy + 'VLLM_PORT': + lambda: int(os.getenv('VLLM_PORT', '0')) + if 'VLLM_PORT' in os.environ else None, + + # path used for ipc when the frontend api server is running in + # multi-processing mode to communicate with the backend engine process. + 'VLLM_RPC_BASE_PATH': + lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()), + + # If true, will load models from ModelScope instead of Hugging Face Hub. + # note that the value is true or false, not numbers + "VLLM_USE_MODELSCOPE": + lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", + + # Interval in seconds to log a warning message when the ring buffer is full + "VLLM_RINGBUFFER_WARNING_INTERVAL": + lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), + + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. + "CUDA_HOME": + lambda: os.environ.get("CUDA_HOME", None), + + # Path to the NCCL library file. 
It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "VLLM_NCCL_SO_PATH": + lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), + + # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": + lambda: os.environ.get("LD_LIBRARY_PATH", None), + + # flag to control if vllm should use triton flash attention + "VLLM_USE_TRITON_FLASH_ATTN": + lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in + ("true", "1")), + + # Force vllm to use a specific flash-attention version (2 or 3), only valid + # when using the flash-attention backend. + "VLLM_FLASH_ATTN_VERSION": + lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)), + + # Internal flag to enable Dynamo fullgraph capture + "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": + lambda: bool( + os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": + lambda: int(os.environ.get("LOCAL_RANK", "0")), + + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": + lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), + + # timeout for each iteration in the engine + "VLLM_ENGINE_ITERATION_TIMEOUT_S": + lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), + + # API key for vLLM API server + "VLLM_API_KEY": + lambda: os.environ.get("VLLM_API_KEY", None), + + # Whether to log responses from API Server for debugging + "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": + lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"). + lower() == "true", + + # S3 access information, used for tensorizer to load model from S3 + "S3_ACCESS_KEY_ID": + lambda: os.environ.get("S3_ACCESS_KEY_ID", None), + "S3_SECRET_ACCESS_KEY": + lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": + lambda: os.environ.get("S3_ENDPOINT_URL", None), + + # Usage stats collection + "VLLM_USAGE_STATS_SERVER": + lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), + "VLLM_NO_USAGE_STATS": + lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DO_NOT_TRACK": + lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( + "DO_NOT_TRACK", None) or "0") == "1", + "VLLM_USAGE_SOURCE": + lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), + + # Logging configuration + # If set to 0, vllm will not configure logging + # If set to 1, vllm will configure logging using the default configuration + # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH + "VLLM_CONFIGURE_LOGGING": + lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": + lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + + # this is used for configuring the default logging level + "VLLM_LOGGING_LEVEL": + lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), + + # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages + "VLLM_LOGGING_PREFIX": + lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), + + # if set, vllm will call logits processors in a thread pool with this many + # threads. This is useful when using custom logits processors that either + # (a) launch additional CUDA kernels or (b) do significant CPU-bound work + # while not holding the python GIL, or both. 
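+ # e.g. VLLM_LOGITS_PROCESSOR_THREADS=8 uses a pool of 8 threads; when
+ # the variable is unset, no thread pool is created (value is None).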
+ "VLLM_LOGITS_PROCESSOR_THREADS": + lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0")) + if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None, + + # Trace function calls + # If set to 1, vllm will trace function calls + # Useful for debugging + "VLLM_TRACE_FUNCTION": + lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), + + # Backend for attention computation + # Available options: + # - "TORCH_SDPA": use torch.nn.MultiheadAttention + # - "FLASH_ATTN": use FlashAttention + # - "XFORMERS": use XFormers + # - "ROCM_FLASH": use ROCmFlashAttention + # - "FLASHINFER": use flashinfer + # - "FLASHMLA": use FlashMLA + "VLLM_ATTENTION_BACKEND": + lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + + # If set, vllm will use flashinfer sampler + "VLLM_USE_FLASHINFER_SAMPLER": + lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) + if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, + + # If set, vllm will force flashinfer to use tensor cores; + # otherwise will use heuristic based on model architecture. + "VLLM_FLASHINFER_FORCE_TENSOR_CORES": + lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))), + + # Pipeline stage partition strategy + "VLLM_PP_LAYER_PARTITION": + lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), + + # (CPU backend only) CPU key-value cache space. + # default is 4 GiB + "VLLM_CPU_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + + # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", + # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. + "VLLM_CPU_OMP_THREADS_BIND": + lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), + + # (CPU backend only) whether to use prepack for MoE layer. This will be + # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might + # need to set this to "0" (False). + "VLLM_CPU_MOE_PREPACK": + lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), + + # If the env var is set, then all workers will execute as separate + # processes from the engine, and we use the same mechanism to trigger + # execution on all workers. + # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. + "VLLM_USE_RAY_SPMD_WORKER": + lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), + + # If the env var is set, it uses the Ray's Compiled Graph + # (previously known as ADAG) API which optimizes the + # control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + # Note that this variable is set to 1 in V1 by default + # when ray distributed executor is used. + "VLLM_USE_RAY_COMPILED_DAG": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), + + # If the env var is set, Ray Compiled Graph uses the specified + # channel type to communicate between workers belonging to + # different pipeline-parallel stages. + # Available options: + # - "auto": use the default channel type + # - "nccl": use NCCL for communication + # - "shm": use shared memory and gRPC for communication + # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": + lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto"), + + # If the env var is set, it enables GPU communication overlap + # (experimental feature) in Ray's Compiled Graph. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. 
+ "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) + ), + + # Use dedicated multiprocess context for workers. + # Both spawn and fork work + "VLLM_WORKER_MULTIPROC_METHOD": + lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), + + # Path to the cache for storing downloaded assets + "VLLM_ASSETS_CACHE": + lambda: os.path.expanduser( + os.getenv( + "VLLM_ASSETS_CACHE", + os.path.join(get_default_cache_root(), "vllm", "assets"), + )), + + # Timeout for fetching images when serving multimodal models + # Default is 5 seconds + "VLLM_IMAGE_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), + + # Timeout for fetching videos when serving multimodal models + # Default is 30 seconds + "VLLM_VIDEO_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30")), + + # Timeout for fetching audio when serving multimodal models + # Default is 10 seconds + "VLLM_AUDIO_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), + + # Cache size (in GiB) for multimodal input cache + # Default is 4 GiB + "VLLM_MM_INPUT_CACHE_GIB": + lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), + + # Path to the XLA persistent cache directory. + # Only used for XLA devices such as TPUs. + "VLLM_XLA_CACHE_PATH": + lambda: os.path.expanduser( + os.getenv( + "VLLM_XLA_CACHE_PATH", + os.path.join(get_default_cache_root(), "vllm", "xla_cache"), + )), + + # If set, assert on XLA recompilation after each execution step. + "VLLM_XLA_CHECK_RECOMPILATION": + lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))), + "VLLM_FUSED_MOE_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), + + # If set, vllm will skip the deprecation warnings. + "VLLM_NO_DEPRECATION_WARNING": + lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), + + # If set, the OpenAI API server will stay alive even after the underlying + # AsyncLLMEngine errors and stops serving requests + "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": + lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)), + + # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows + # the user to specify a max sequence length greater than + # the max length derived from the model's config.json. + # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": + lambda: + (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in + ("1", "true")), + + # If set, forces FP8 Marlin to be used for FP8 quantization regardless + # of the hardware support for FP8 compute. + "VLLM_TEST_FORCE_FP8_MARLIN": + lambda: + (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in + ("1", "true")), + "VLLM_TEST_FORCE_LOAD_FORMAT": + lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"), + + # Time in ms for the zmq client to wait for a response from the backend + # server for simple data operations + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "1000000")), + + # a list of plugin names to load, separated by commas. + # if this is not set, it means all plugins will be loaded + # if this is set to an empty string, no plugins will be loaded + "VLLM_PLUGINS": + lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ + "VLLM_PLUGINS"].split(","), + + # Enables torch profiler if set. Path to the directory where torch profiler + # traces are saved. Note that it must be an absolute path. 
+ "VLLM_TORCH_PROFILER_DIR": + lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os + .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + + # If set, vLLM will use Triton implementations of AWQ. + "VLLM_USE_TRITON_AWQ": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), + + # If set, allow loading or unloading lora adapters in runtime, + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": + lambda: + (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in + ("1", "true")), + + # By default, vLLM will check the peer-to-peer capability itself, + # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa + # If this env var is set to 1, vLLM will skip the peer-to-peer check, + # and trust the driver's peer-to-peer capability report. + "VLLM_SKIP_P2P_CHECK": + lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1", + + # List of quantization kernels that should be disabled, used for testing + # and performance comparisons. Currently only affects MPLinearKernel + # selection + # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) + "VLLM_DISABLED_KERNELS": + lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ + "VLLM_DISABLED_KERNELS"].split(","), + + # If set, use the V1 code path. + "VLLM_USE_V1": + lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), + + # Disable aiter ops unless specifically enabled. + # Acts as a parent switch to enable the rest of the other operations. + "VLLM_ROCM_USE_AITER": + lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in + ("true", "1")), + + # use aiter linear op if aiter ops are enabled + # The following list of related ops + # - scaled_mm (per-tensor / rowwise) + "VLLM_ROCM_USE_AITER_LINEAR": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in + ("true", "1")), + + # Whether to use aiter moe ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MOE": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in + ("true", "1")), + + # Whether to use aiter block scaled moe kernel. + # By default this is disabled. + "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE": + lambda: + (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in + ("true", "1")), + + # use aiter rms norm op if aiter ops are enabled. + "VLLM_ROCM_USE_AITER_RMSNORM": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in + ("true", "1")), + + # Pad the fp8 weights to 256 bytes for ROCm + "VLLM_ROCM_FP8_PADDING": + lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), + + # Pad the weights for the moe kernel + "VLLM_ROCM_MOE_PADDING": + lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), + + # custom paged attention kernel for MI3* cards + "VLLM_ROCM_CUSTOM_PAGED_ATTN": + lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in + ("true", "1")), + + # Divisor for dynamic query scale factor calculation for FP8 KV Cache + "Q_SCALE_CONSTANT": + lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), + # Divisor for dynamic key scale factor calculation for FP8 KV Cache + "K_SCALE_CONSTANT": + lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + # Divisor for dynamic value scale factor calculation for FP8 KV Cache + "V_SCALE_CONSTANT": + lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), + + # If set, enable multiprocessing in LLM for the V1 code path. 
+ "VLLM_ENABLE_V1_MULTIPROCESSING": + lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), + "VLLM_LOG_BATCHSIZE_INTERVAL": + lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), + "VLLM_DISABLE_COMPILE_CACHE": + lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), + + # If set, vllm will run in development mode, which will enable + # some additional endpoints for developing and debugging, + # e.g. `/reset_prefix_cache` + "VLLM_SERVER_DEV_MODE": + lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), + + # Controls the maximum number of requests to handle in a + # single asyncio task when processing per-token outputs in the + # V1 AsyncLLM interface. It is applicable when handling a high + # concurrency of streaming requests. + # Setting this too high can result in a higher variance of + # inter-message latencies. Setting it too low can negatively impact + # TTFT and overall throughput. + "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")), + + # If set, vLLM will disable the MLA attention optimizations. + "VLLM_MLA_DISABLE": + lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), + + # If set, vLLM will use the Triton implementation of moe_align_block_size, + # i.e. moe_align_block_size_triton in fused_moe.py. + "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON": + lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) + ), + + # Number of GPUs per worker in Ray, if it is set to be a fraction, + # it allows ray to schedule multiple actors on a single GPU, + # so that users can colocate other actors on the same GPUs as vLLM. + "VLLM_RAY_PER_WORKER_GPUS": + lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")), + + # Bundle indices for Ray, if it is set, it can control precisely + # which indices are used for the Ray bundle, for every worker. + # Format: comma-separated list of integers, e.g. "0,1,2,3" + "VLLM_RAY_BUNDLE_INDICES": + lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), + + # In some system, find_loaded_library() may not work. So we allow users to + # specify the path through environment variable VLLM_CUDART_SO_PATH. + "VLLM_CUDART_SO_PATH": + lambda: os.getenv("VLLM_CUDART_SO_PATH", None), + + # Contiguous cache fetching to avoid using costly gather operation on + # Gaudi3. This is only applicable to HPU contiguous cache. If set to true, + # contiguous cache fetch will be used. + "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH": + lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in + ("1", "true"), + + # Rank of the process in the data parallel setting + "VLLM_DP_RANK": + lambda: int(os.getenv("VLLM_DP_RANK", "0")), + + # Rank of the process in the data parallel setting. + # Defaults to VLLM_DP_RANK when not set. + "VLLM_DP_RANK_LOCAL": + lambda: int( + os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)), + + # World size of the data parallel setting + "VLLM_DP_SIZE": + lambda: int(os.getenv("VLLM_DP_SIZE", "1")), + + # IP address of the master node in the data parallel setting + "VLLM_DP_MASTER_IP": + lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), + + # Port of the master node in the data parallel setting + "VLLM_DP_MASTER_PORT": + lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), + + # Whether to use S3 path for model loading in CI via RunAI Streamer + "VLLM_CI_USE_S3": + lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", + + # Use model_redirect to redirect the model name to a local folder. 
+ "VLLM_MODEL_REDIRECT_PATH": + lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None), + + # Whether to use atomicAdd reduce in gptq/awq marlin kernel. + "VLLM_MARLIN_USE_ATOMIC_ADD": + lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", + + # Whether to turn on the outlines cache for V0 + # This cache is unbounded and on disk, so it's not safe to use in + # an environment with potentially malicious users. + "VLLM_V0_USE_OUTLINES_CACHE": + lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", + + # If set, disables TPU-specific optimization for top-k & top-p sampling + "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION": + lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"])) + if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None, + + # Gap between padding buckets for the forward pass. So we have + # 8, we will run forward pass with [16, 24, 32, ...]. + "VLLM_TPU_BUCKET_PADDING_GAP": + lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) + if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, + + # Allow use of DeepGemm kernels for fused moe ops. + "VLLM_USE_DEEP_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + + # vLLM do not support W4A8 and W8A8, we add it. For MOE, we default use W8A8, + # If set to true, we use W4A8. + "VLLM_W8A8_MOE_USE_W4A8": + lambda: os.environ.get("VLLM_W8A8_MOE_USE_W4A8", "0").lower() in + ("1", "true"), + + # If set to true, we use int8 mla attention for decode stage. + + "VLLM_USE_INT8_MLA": + lambda: os.environ.get("VLLM_USE_INT8_MLA", "0").lower() in + ("1", "true"), + + # For W4A8 MOE, we default use TN gemm format, choices: [TN, NN]. + "VLLM_W4A8_FORMAT": + lambda: os.environ.get("VLLM_W4A8_FORMAT", "TN").upper(), + + "VLLM_W4A8_VERSION": + # For W4A8 MOE, we default use version 2, choices: [1, 2]. + lambda: int(os.environ.get("VLLM_W4A8_VERSION", "2")), + + # temp param to support compressed-tensor's multi-quantization + "VLLM_MIX_QUANTIZATION_TYPE": + lambda: os.environ.get("VLLM_MIX_QUANTIZATION_TYPE", "").upper(), + + # Use Customize mlp impl for faster speed and less gpu memory usage. + "VLLM_MLA_CUSTOMIZE": + lambda: os.environ.get("VLLM_MLA_CUSTOMIZE", "1").lower() in + ("1", "true"), +} + +# end-env-vars-definition + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) + + +def is_set(name: str): + """Check if an environment variable is explicitly set.""" + if name in environment_variables: + return name in os.environ + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def set_vllm_use_v1(use_v1: bool): + if is_set("VLLM_USE_V1"): + raise ValueError( + "Should not call set_vllm_use_v1() if VLLM_USE_V1 is set " + "explicitly by the user. Please raise this as a Github " + "Issue and explicitly set VLLM_USE_V1=0 or 1.") + os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" + + +def compute_hash() -> str: + """ + WARNING: Whenever a new key is added to this environment + variables, ensure that it is included in the factors list if + it affects the computation graph. For example, different values + of VLLM_PP_LAYER_PARTITION will generate different computation + graphs, so it is included in the factors list. 
The env vars that + affect the choice of different kernels or attention backends should + also be included in the factors list. + """ + factors: list[Any] = [] + + # summarize environment variables + def factorize(name: str): + if __getattr__(name): + factors.append(__getattr__(name)) + else: + factors.append("None") + + # The values of envs may affects the computation graph. + # TODO(DefTruth): hash all environment variables? + # for key in environment_variables: + # factorize(key) + environment_variables_to_hash = [ + "VLLM_PP_LAYER_PARTITION", + "VLLM_MLA_DISABLE", + "VLLM_USE_TRITON_FLASH_ATTN", + "VLLM_USE_TRITON_AWQ", + "VLLM_DP_RANK", + "VLLM_DP_SIZE", + ] + for key in environment_variables_to_hash: + if key in environment_variables: + factorize(key) + + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + + return hash_str diff --git a/vllm/executor/__init__.py b/vllm/executor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py new file mode 100644 index 0000000..58796e5 --- /dev/null +++ b/vllm/executor/executor_base.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import time +from abc import ABC, abstractmethod +from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, + Union) + +import torch.nn as nn +from typing_extensions import TypeVar + +import vllm.platforms +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest, PoolerOutput +from vllm.utils import make_async +from vllm.worker.worker_base import WorkerBase + +logger = init_logger(__name__) + +_R = TypeVar("_R", default=Any) + + +class ExecutorBase(ABC): + """Base class for all executors. + + An executor is responsible for executing the model on one device, + or it can be a distributed executor + that can execute the model on multiple devices. + """ + + uses_ray: bool # whether the executor uses Ray for orchestration. + + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self._init_executor() + self.is_sleeping = False + self.sleeping_tags: set[str] = set() + + @abstractmethod + def _init_executor(self) -> None: + raise NotImplementedError + + @abstractmethod + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict[str, Any]] = None) -> List[_R]: + """ + Execute an RPC call on all workers. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. + + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. 
+ timeout: Maximum time in seconds to wait for execution. Raises a + :exc:`TimeoutError` on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + raise NotImplementedError + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + Normally, this should simply delegate to the underlying Worker. Some + ExecutorBase may require modification of the result, e.g. to ensure the + selected cache sizes are compatible with all workers. + + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + results = self.collective_rpc("determine_num_available_blocks") + a = min([r[0] for r in results]) + b = min([r[1] for r in results]) + return a, b + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 workers. + logger.info("# %s blocks: %d, # CPU blocks: %d", + vllm.platforms.current_platform.device_name, + num_gpu_blocks, num_cpu_blocks) + max_concurrency = (num_gpu_blocks * self.cache_config.block_size / + self.model_config.max_model_len) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + self.model_config.max_model_len, max_concurrency) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.collective_rpc("initialize_cache", + args=(num_gpu_blocks, num_cpu_blocks)) + + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + """ + Run a function directly on the model inside each worker, + returning the result for each of them. + """ + + def rpc_func(worker: WorkerBase) -> _R: + return func(worker.get_model()) + + return self.collective_rpc(rpc_func) + + def execute_model( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + output = self.collective_rpc("execute_model", + args=(execute_model_req, )) + return output[0] + + def stop_remote_worker_execution_loop(self) -> None: + """Releases parallel workers from model loop.""" + return + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("add_lora", args=(lora_request, ))) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("remove_lora", args=(lora_id, ))) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("pin_lora", args=(lora_id, ))) + + def list_loras(self) -> Set[int]: + sets = self.collective_rpc("list_loras") + for s in sets: + assert s == sets[0], "All workers should have the same LORAs." 
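`collective_rpc` accepts either a worker method name or a picklable callable that receives the worker as its first argument, and `apply_model` builds on it to run a function against the loaded `nn.Module` in every worker. A minimal usage sketch, assuming `executor` is an already-initialized `ExecutorBase` subclass (the callable below is hypothetical):

```python
def count_parameters(model):
    # Runs inside each worker on that worker's loaded model.
    return sum(p.numel() for p in model.parameters())

# One result per worker.
per_worker_param_counts = executor.apply_model(count_parameters)

# Dispatch by method name instead of a callable:
per_worker_blocks = executor.collective_rpc("determine_num_available_blocks")
```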
+ return sets[0] + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("add_prompt_adapter", + args=(prompt_adapter_request, ))) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("remove_prompt_adapter", + args=(prompt_adapter_id, ))) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("pin_prompt_adapter", + args=(prompt_adapter_id, ))) + + def list_prompt_adapters(self) -> Set[int]: + sets = self.collective_rpc("list_prompt_adapters") + for s in sets: + assert (s == sets[0] + ), "All workers should have the same prompt adapters." + return sets[0] + + def start_profile(self) -> None: + self.collective_rpc("start_profile") + + def stop_profile(self) -> None: + self.collective_rpc("stop_profile") + + def sleep(self, level: int = 1): + if self.is_sleeping: + logger.warning("Executor is already sleeping.") + return + time_before_sleep = time.perf_counter() + self.collective_rpc("sleep", kwargs=dict(level=level)) + time_after_sleep = time.perf_counter() + self.sleeping_tags = {"weights", "kv_cache"} + self.is_sleeping = True + logger.info("It took %.6f seconds to fall asleep.", + time_after_sleep - time_before_sleep) + + def wake_up(self, tags: Optional[list[str]] = None): + if not self.is_sleeping: + logger.warning("Executor is not sleeping.") + return + if tags: + for tag in tags: + if tag not in self.sleeping_tags: + logger.warning("Tag %s is not in sleeping tags %s", tag, + self.sleeping_tags) + return + time_before_wakeup = time.perf_counter() + self.collective_rpc("wake_up", kwargs=dict(tags=tags)) + time_after_wakeup = time.perf_counter() + logger.info("It took %.6f seconds to wake up tags %s.", + time_after_wakeup - time_before_wakeup, + tags if tags is not None else self.sleeping_tags) + if tags: + for tag in tags: + self.sleeping_tags.remove(tag) + else: + self.sleeping_tags.clear() + if not self.sleeping_tags: + self.is_sleeping = False + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.collective_rpc("save_sharded_state", + kwargs=dict(path=path, + pattern=pattern, + max_size=max_size)) + + @abstractmethod + def check_health(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + raise NotImplementedError + + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + output = await make_async(self.execute_model)(execute_model_req) + return output + + async def stop_remote_worker_execution_loop_async(self) -> None: + """Releases parallel workers from model loop.""" + return + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. 
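The `sleep`/`wake_up` pair above implements a small tag protocol: sleeping marks both `"weights"` and `"kv_cache"` as offloaded, and workers can then be woken selectively. A sketch of the call sequence, assuming an initialized `executor` whose workers support sleep mode:

```python
executor.sleep(level=1)          # sleeping_tags becomes {"weights", "kv_cache"}
executor.wake_up(["weights"])    # restore weights only; "kv_cache" stays asleep
executor.wake_up()               # no tags: wake everything, is_sleeping -> False
```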
If not, it should raise an + exception.""" + self.check_health() + + +class DistributedExecutorBase(ExecutorBase): + """Abstract superclass of distributed executor implementations.""" + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + + super().__init__(*args, **kwargs) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + # TODO: unify into collective_rpc + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_tensor_parallel_workers_only=True) + + # Only the driver worker returns the sampling results. + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model(execute_model_req=None) + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + @abstractmethod + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution loop + running in each of the remote workers. In this case, this method + returns None. Otherwise, this method returns the model output. + """ + raise NotImplementedError + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + return self._run_workers(method, *args, **(kwargs or {})) + + @abstractmethod + def _run_workers( + self, + method: Union[str, Callable], + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + + # TODO: simplify and merge with collective_rpc + """ + raise NotImplementedError + + @abstractmethod + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + raise NotImplementedError + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. 
+ return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + @abstractmethod + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + """Execute the model asynchronously in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + + @abstractmethod + async def _start_worker_execution_loop(self): + """Run execution loop on all workers. It guarantees all workers run + the loop or None of them is running the loop. Loop can be stopped by + `stop_remote_worker_execution_loop`. + The API is idempotent (guarantee only 1 loop run at any moment).""" + raise NotImplementedError diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py new file mode 100644 index 0000000..d1f8c36 --- /dev/null +++ b/vllm/executor/mp_distributed_executor.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import os +from typing import Any, Callable, List, Optional, Union + +import cloudpickle + +from vllm.executor.executor_base import DistributedExecutorBase +from vllm.executor.multiproc_worker_utils import ( + ProcessWorkerWrapper, ResultHandler, WorkerMonitor, + set_multiprocessing_worker_envs) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, + get_distributed_init_method, get_ip, get_open_port, + make_async, run_method, update_environment_variables) +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class MultiprocessingDistributedExecutor(DistributedExecutorBase): + """Python multiprocessing-based distributed executor""" + + uses_ray: bool = False + + def _check_cuda(self) -> None: + """Check that the number of GPUs is sufficient for the parallel + configuration. Separate from _init_executor to reduce the number of + indented blocks. + """ + parallel_config = self.parallel_config + world_size = parallel_config.world_size + tensor_parallel_size = parallel_config.tensor_parallel_size + + cuda_device_count = cuda_device_count_stateless() + # Use confusing message for more common TP-only case. + if tensor_parallel_size > cuda_device_count: + raise RuntimeError( + f"please set tensor_parallel_size ({tensor_parallel_size}) " + f"to less than max local gpu count ({cuda_device_count})") + + if world_size > cuda_device_count: + raise RuntimeError( + f"please ensure that world_size ({world_size}) " + f"is less than than max local gpu count ({cuda_device_count})") + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) + + def _init_executor(self) -> None: + + from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): + self._check_cuda() + + # Create the parallel GPU workers. 
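`_check_cuda` above also fills in `CUDA_VISIBLE_DEVICES` for the driver when it is unset, using one device index per rank. A tiny sketch of the default it computes (the world size value is illustrative; in vLLM it is the product of the tensor and pipeline parallel sizes):

```python
world_size = 4  # e.g. tensor_parallel_size=4, pipeline_parallel_size=1
default_visible_devices = ",".join(map(str, range(world_size)))
print(default_visible_devices)  # "0,1,2,3"
```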
+ world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Set multiprocessing envs that are common to V0 and V1 + set_multiprocessing_worker_envs(self.parallel_config) + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + self.workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[ProcessWorkerWrapper] = [] + + if world_size == 1: + self.worker_monitor = None + else: + result_handler = ResultHandler() + for rank in range(1, world_size): + worker = ProcessWorkerWrapper(result_handler, + WorkerWrapperBase, + self.vllm_config, rank) + self.workers.append(worker) + if rank % tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + # Set up signal handlers to shutdown the executor cleanly + # sometimes gc does not work well + + self.driver_worker = WorkerWrapperBase(self.vllm_config, 0) + + all_kwargs = [] + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + for i in range(world_size): + local_rank = i + rank = i + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + all_kwargs.append(kwargs) + self._run_workers("init_worker", all_kwargs) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + self.driver_exec_model = make_async(self.driver_worker.execute_model) + self.pp_locks: Optional[List[asyncio.Lock]] = None + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_model(execute_model_req) + + def _run_workers( + self, + method: Union[str, Callable], + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> List[Any]: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. 
+ """ + if isinstance(method, str): + sent_method = method + else: + sent_method = cloudpickle.dumps(method) + del method + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + if async_run_tensor_parallel_workers_only: + # Run only non-driver workers and just return futures. + return [ + worker.execute_method(sent_method, *args, **kwargs) + for worker in self.non_driver_workers + ] + + # Start all remote workers first. + worker_outputs = [ + worker.execute_method(sent_method, *args, **kwargs) + for worker in self.workers + ] + + driver_worker_output = run_method(self.driver_worker, sent_method, + args, kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + if not self.tp_driver_workers: + return await self.driver_exec_model(execute_model_req) + + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_model, self.pp_locks[0], + execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method_async, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method_async("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py new file mode 100644 index 0000000..e680d53 --- /dev/null +++ b/vllm/executor/msgspec_utils.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 + +from array import array +from typing import Any, Type + +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE + + +def encode_hook(obj: Any) -> Any: + """Custom msgspec enc hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if isinstance(obj, array): + assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( + f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " + f"Given array has a type code of {obj.typecode}.") + return obj.tobytes() + + +def decode_hook(type: Type, obj: Any) -> Any: + """Custom msgspec dec hook that supports array types. 
+ + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if type is array: + deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) + deserialized.frombytes(obj) + return deserialized diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py new file mode 100644 index 0000000..380b672 --- /dev/null +++ b/vllm/executor/multiproc_worker_utils.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import os +import sys +import threading +import uuid +from dataclasses import dataclass +from multiprocessing import Queue +from multiprocessing.connection import wait +from multiprocessing.process import BaseProcess +from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, + TypeVar, Union) + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import _maybe_force_spawn, get_mp_context, run_method + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +JOIN_TIMEOUT_S = 2 + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + assert self.result is not None + if self.result.exception is not None: + raise self.result.exception + return self.result.value # type: ignore[return-value] + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if not loop.is_closed(): + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = get_mp_context().Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for task_id, future in self.tasks.items(): + _set_future_result( + future, + Result(task_id=task_id, + exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['ProcessWorkerWrapper'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([w.process.sentinel for w in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + process = worker.process + if process.sentinel in dead_sentinels: + process.join(JOIN_TIMEOUT_S) + if process.exitcode is not None 
and process.exitcode != 0: + logger.error("Worker %s pid %s died, exit code: %s", + process.name, process.pid, process.exitcode) + # Cleanup any remaining workers + if logger: + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.process.join(JOIN_TIMEOUT_S) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class ProcessWorkerWrapper: + """Local process wrapper for vllm.worker.Worker, + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[VllmConfig, int], Any], + vllm_config: VllmConfig, rank: int) -> None: + self.mp = get_mp_context() + self._task_queue = self.mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined] + target=_run_worker_process, + name="VllmWorkerProcess", + kwargs=dict( + worker_factory=worker_factory, + task_queue=self._task_queue, + result_queue=self.result_queue, + vllm_config=vllm_config, + rank=rank, + ), + daemon=True) + + self.process.start() + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: Union[str, bytes], args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + future: ResultFuture = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: Union[str, bytes], *args, + **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.process.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.process.kill() + + +def _run_worker_process( + worker_factory: Callable[[VllmConfig, int], Any], + task_queue: Queue, + result_queue: Queue, + vllm_config: VllmConfig, + rank: int, +) -> None: + """Worker process event loop""" + + # Add process-specific prefix to stdout and stderr + process_name = get_mp_context().current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + # Initialize worker + worker = worker_factory(vllm_config, rank) + del worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + output = run_method(worker, method, args, kwargs) + except SystemExit: + raise + except KeyboardInterrupt: + break + except BaseException as e: + logger.exception( + "Exception in worker %s while processing method %s.", + process_name, method) + exception = e + result_queue.put( 
+ Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + # Flush TunableOp results when TunableOp is enabled and + # online (in situ) tuning is enabled. + # Offline tuning API (record_untuned_is_enabled()) only + # available in PyTorch 2.6 or later. + if torch.cuda.is_available(): + import torch.cuda.tunable as tunable + if (tunable.is_enabled() and tunable.tuning_is_enabled() + and not tunable.record_untuned_is_enabled()): + tunable.write_file() + + logger.info("Worker exiting") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Prepend each output line with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] + + +def set_multiprocessing_worker_envs(parallel_config): + """ Set up environment variables that should be used when there are workers + in a multiprocessing environment. This should be called by the parent + process before worker processes are created""" + + _maybe_force_spawn() + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if "OMP_NUM_THREADS" not in os.environ and ( + current_parallelism := + torch.get_num_threads()) > default_omp_num_threads: + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. 
Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, default_omp_num_threads) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py new file mode 100644 index 0000000..9b0b987 --- /dev/null +++ b/vllm/executor/ray_distributed_executor.py @@ -0,0 +1,700 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import json +import os +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import cloudpickle +import msgspec + +import vllm.envs as envs +from vllm.executor.executor_base import ( + DistributedExecutorBase) # yapf: disable +from vllm.executor.msgspec_utils import encode_hook +from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, + ray) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (_run_task_with_lock, get_distributed_init_method, + get_ip, get_open_port, make_async) + +if ray is not None: + from ray.actor import ActorHandle + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +else: + ActorHandle = None + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +@dataclass +class RayWorkerMetaData: + """ + Metadata for a Ray worker. + The order of ray worker creation can be random, + and we need to reset the rank after creating all workers. + """ + worker: ActorHandle + created_rank: int + adjusted_rank: int = -1 + ip: str = "" + + +class RayDistributedExecutor(DistributedExecutorBase): + """Ray-based distributed executor""" + + # These env vars are worker-specific, therefore are NOT copied + # from the driver to the workers + WORKER_SPECIFIC_ENV_VARS = { + "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES" + } + + config_home = envs.VLLM_CONFIG_ROOT + # This file contains a list of env vars that should not be copied + # from the driver to the Ray workers. + non_carry_over_env_vars_file = os.path.join( + config_home, "ray_non_carry_over_env_vars.json") + if os.path.exists(non_carry_over_env_vars_file): + with open(non_carry_over_env_vars_file) as f: + non_carry_over_env_vars = set(json.load(f)) + else: + non_carry_over_env_vars = set() + + uses_ray: bool = True + + def _init_executor(self) -> None: + self.forward_dag: Optional[ray.dag.CompiledDAG] = None + if envs.VLLM_USE_V1: + # V1 uses SPMD worker and compiled DAG + os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" + os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" + + # For TPU, avoid compiling NVIDIA's NCCL + if current_platform.is_tpu(): + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" + + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + # Currently, this requires USE_RAY_SPMD_WORKER=True. + self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG + # If the env var is set, then we do not distinguish between the + # "driver worker" vs other workers. Also, the rank 0 worker will + # be executed in a remote Ray worker. Currently this requires + # USE_RAY_COMPILED_DAG=True. 
+ self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if self.use_ray_compiled_dag: + assert self.use_ray_spmd_worker, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires " + "VLLM_USE_RAY_SPMD_WORKER=1") + if self.use_ray_spmd_worker: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert self.use_ray_compiled_dag, ( + "VLLM_USE_RAY_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + + assert self.uses_ray + initialize_ray_cluster(self.parallel_config) + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel GPU workers. + self._init_workers_ray(placement_group) + + self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + self.output_decoder = msgspec.msgpack.Decoder( + Optional[List[SamplerOutput]]) + self.use_v1 = envs.VLLM_USE_V1 + + self.pp_locks: Optional[List[asyncio.Lock]] = None + if not self.use_ray_compiled_dag: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) + + def shutdown(self) -> None: + logger.info( + "Shutting down Ray distributed executor. If you see error log " + "from logging.cc regarding SIGTERM received, please ignore because " + "this is the expected termination process in Ray.") + if hasattr(self, "forward_dag") and self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) + self.forward_dag = None + + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs + + # child class could overwrite this to return actual env vars. + def _get_env_vars_to_be_updated(self): + return self._env_vars_for_all_workers + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. + bundle_indices: List[int] + if envs.VLLM_RAY_BUNDLE_INDICES: + # Use the bundle indices specified by the user. 
+ bundle_indices = list( + map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) + assert len(bundle_indices) == self.parallel_config.world_size, \ + ("VLLM_RAY_BUNDLE_INDICES must have the same size" + f" as the world size, but got {bundle_indices=} " + f"and {self.parallel_config.world_size=}") + assert len(set(bundle_indices)) == len(bundle_indices), \ + ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," + f" but got {bundle_indices=}") + else: + # use the first N bundles that have GPU resources. + bundle_indices = [] + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if bundle.get(current_platform.ray_device_key, 0): + bundle_indices.append(bundle_id) + bundle_indices = bundle_indices[:self.parallel_config.world_size] + + worker_metadata: List[RayWorkerMetaData] = [] + driver_ip = get_ip() + for rank, bundle_id in enumerate(bundle_indices): + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + if current_platform.ray_device_key == "GPU": + # NV+AMD GPUs, and Intel XPUs + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, + rpc_rank=rank) + else: + worker = ray.remote( + num_cpus=0, + num_gpus=0, + resources={current_platform.ray_device_key: num_gpus}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, + rpc_rank=rank) + worker_metadata.append( + RayWorkerMetaData(worker=worker, created_rank=rank)) + + worker_ips = ray.get([ + each.worker.get_node_ip.remote() # type: ignore[attr-defined] + for each in worker_metadata + ]) + + for each, ip in zip(worker_metadata, worker_ips): + each.ip = ip + + if not self.use_ray_spmd_worker: + for i, each in enumerate(worker_metadata): + # find and remove the dummy worker from the list + worker = each.worker + worker_ip = each.ip + if self.driver_dummy_worker is None and worker_ip == driver_ip: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + vllm_config=self.vllm_config, rpc_rank=0) + worker_metadata.pop(i) + break + + logger.debug("workers: %s", worker_metadata) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node." + f"Driver IP: {driver_ip}, worker IPs: {worker_ips}." + "Consider adjusting the Ray placement group or running " + "the driver on a GPU node.") + + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = item.ip + return (0 if ip == driver_ip else 1, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. 
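`sort_by_driver_then_worker_ip` orders workers by three keys: driver node first, then nodes hosting fewer workers, then IP string order. A worked example with hypothetical addresses:

```python
driver_ip = "10.0.0.1"
ip_counts = {"10.0.0.1": 1, "10.0.0.2": 2, "10.0.0.3": 1}  # workers per node
worker_ips = ["10.0.0.2", "10.0.0.3", "10.0.0.2", "10.0.0.1"]

def sort_key(ip):
    return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

print(sorted(worker_ips, key=sort_key))
# ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']
# driver node first, then the less crowded node, then the busier node
```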
+ sorted_worker_metadata = sorted(worker_metadata, + key=sort_by_driver_then_worker_ip) + start_rank = 0 if self.use_ray_spmd_worker else 1 + for i, item in enumerate(sorted_worker_metadata): + item.adjusted_rank = i + start_rank + self.workers = [item.worker for item in sorted_worker_metadata] + rerank_mapping = { + item.created_rank: item.adjusted_rank + for item in sorted_worker_metadata + } + self._run_workers("adjust_rank", rerank_mapping) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP`" + " environment variable, make sure it is unique for" + " each node.") + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [{ + current_platform.device_control_env_var: + ",".join(map(str, node_gpus[node_id])), + } for (node_id, _) in worker_node_and_gpu_ids] + + # Environment variables to copy from driver to workers + env_vars_to_copy = [ + v for v in envs.environment_variables + if v not in self.WORKER_SPECIFIC_ENV_VARS + and v not in self.non_carry_over_env_vars + ] + + env_vars_to_copy.extend(current_platform.additional_env_vars) + + # Copy existing env vars to each worker's args + for args in all_args_to_update_environment_variables: + # TODO: refactor platform-specific env vars + for name in env_vars_to_copy: + if name in os.environ: + args[name] = os.environ[name] + + logger.info("non_carry_over_env_vars from config: %s", + self.non_carry_over_env_vars) + logger.info( + "Copying the following environment variables to workers: %s", + [v for v in env_vars_to_copy if v in os.environ]) + logger.info( + "If certain env vars should NOT be copied to workers, add them to " + "%s file", self.non_carry_over_env_vars_file) + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. 
Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + all_kwargs = [] + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): + local_rank = node_workers[node_id].index(rank) + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + all_kwargs.append(kwargs) + self._run_workers("init_worker", all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + if self.use_v1: + serialized_data = execute_model_req + else: + serialized_data = self.input_encoder.encode(execute_model_req) + outputs = ray.get(self.forward_dag.execute(serialized_data)) + if self.use_v1: + output = outputs[0] + else: + output = self.output_decoder.decode(outputs[0]) + return output + + def _run_workers( + self, + method: Union[str, Callable], + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. 
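The nested loop above builds `pp_tp_workers` so that each inner list is the TP group of one pipeline stage, exactly as the PP=2, TP=4 comment describes. A standalone sketch of the rank arithmetic:

```python
pipeline_parallel_size, tensor_parallel_size = 2, 4

pp_tp_ranks = [
    [pp_rank * tensor_parallel_size + tp_rank
     for tp_rank in range(tensor_parallel_size)]
    for pp_rank in range(pipeline_parallel_size)
]
print(pp_tp_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```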
Can be used in the following + ways: + + Args: + - async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + - args/kwargs: All workers share the same args/kwargs + """ + if isinstance(method, str): + sent_method = method + else: + sent_method = cloudpickle.dumps(method) + del method + if self.use_ray_spmd_worker: + assert not async_run_tensor_parallel_workers_only, ( + "async_run_tensor_parallel_workers_only is not supported for " + "spmd mode.") + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(sent_method, *args, **kwargs) + for worker in ray_workers + ] + + if async_run_tensor_parallel_workers_only: + # Just return futures + return ray_worker_outputs + + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not self.use_ray_spmd_worker: + # Start the driver worker after all the ray workers. + driver_worker_output = [ + self.driver_worker.execute_method(sent_method, *args, **kwargs) + ] + + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return driver_worker_output + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def _check_ray_cgraph_installation(self): + import pkg_resources + from packaging import version + + required_version = version.parse("2.43.0") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) + if current_version < required_version: + raise ValueError(f"Ray version {required_version} is " + f"required, but found {current_version}") + + import importlib.util + cgraph_spec = importlib.util.find_spec( + "ray.experimental.compiled_dag_ref") + if cgraph_spec is None: + raise ValueError("Ray Compiled Graph is not installed. " + "Run `pip install ray[cgraph]` to install it.") + + cupy_spec = importlib.util.find_spec("cupy") + if (cupy_spec is None + and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"): + raise ValueError( + "cupy is not installed but required since " + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. " + "Run `pip install ray[cgraph]` and check cupy installation.") + + def _compiled_ray_dag(self, enable_asyncio: bool): + assert self.parallel_config.use_ray + self._check_ray_cgraph_installation() + from ray.dag import InputNode, MultiOutputNode + + logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE) + logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + + channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE + if channel_type not in ("auto", "nccl", "shm"): + raise ValueError( + "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: " + f"{channel_type}. 
Valid values are: 'auto', 'nccl', or 'shm'.") + + # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds + # (it is 10 seconds by default). This is a Ray environment variable to + # control the timeout of getting result from a compiled graph execution, + # i.e., the distributed execution that includes model forward runs and + # intermediate tensor communications, in the case of vllm. + os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112 + logger.info("RAY_CGRAPH_get_timeout is set to %s", + os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112 + + with InputNode() as input_data: + # Example DAG: PP=2, TP=4 + # + # For V0: + # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501 + # + # For V1: + # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput # noqa: E501 + + # All workers in the first TP group will take in the + # ExecuteModelRequest as input. + outputs = [input_data for _ in self.pp_tp_workers[0]] + for pp_rank, tp_group in enumerate(self.pp_tp_workers): + # Each PP worker takes in the output of the previous PP worker, + # and the TP group executes in SPMD fashion. + if self.use_v1: + outputs = [ + worker.execute_model_ray. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + else: + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + + last_pp_rank = len(self.pp_tp_workers) - 1 + if (pp_rank < last_pp_rank and + envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"): + # Specify how intermediate tensors should be passed + # between pp stages, no need to specify for the last + # pp stage or when using shared memory (the default). + transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE + outputs = [ + output.with_tensor_transport(transport=transport) + for output in outputs + ] + + forward_dag = MultiOutputNode(outputs) + + return forward_dag.experimental_compile( + enable_asyncio=enable_asyncio, + _overlap_gpu_communication=envs. 
+ VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + + def __del__(self): + self.shutdown() + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + serialized_data = self.input_encoder.encode(execute_model_req) + dag_future = await self.forward_dag.execute_async(serialized_data) + output = await dag_future[0] + return self.output_decoder.decode(output) + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + if not self.tp_driver_workers: + return await self.driver_exec_method("execute_model", + execute_model_req) + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], + "execute_model", execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] + + async def _start_worker_execution_loop(self): + assert not self.use_ray_spmd_worker, ( + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) + + def check_health(self) -> None: + # Assume that the Ray workers are healthy. 
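`_driver_execute_model_async` above guards each pipeline-parallel stage with its own `asyncio.Lock`, launches one task per stage, and keeps only the last stage's output. A generic sketch of that pattern (not vLLM API; `stage_fns` is a hypothetical list of per-stage async callables):

```python
import asyncio

PP_SIZE = 2
pp_locks = [asyncio.Lock() for _ in range(PP_SIZE)]  # one lock per PP stage

async def run_with_lock(lock, coro_fn, *args):
    # Serialize concurrent calls into the same pipeline stage.
    async with lock:
        return await coro_fn(*args)

async def run_all_stages(stage_fns, request):
    tasks = [
        asyncio.create_task(run_with_lock(pp_locks[i], fn, request))
        for i, fn in enumerate(stage_fns)
    ]
    results = await asyncio.gather(*tasks)
    return results[-1]  # only the last pipeline stage holds the final output
```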
+ # TODO: check the health of the Ray workers + return diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py new file mode 100644 index 0000000..ef7df45 --- /dev/null +++ b/vllm/executor/ray_utils.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import time +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import msgspec + +import vllm.platforms +from vllm.config import ParallelConfig +from vllm.executor.msgspec_utils import decode_hook, encode_hook +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.utils import get_ip +from vllm.worker.worker_base import WorkerWrapperBase + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.outputs import ModelRunnerOutput + +logger = init_logger(__name__) +PG_WAIT_TIMEOUT = 1800 + +try: + import ray + from ray.util import placement_group_table + from ray.util.placement_group import PlacementGroup + try: + from ray._private.state import available_resources_per_node + except ImportError: + # Ray 2.9.x doesn't expose `available_resources_per_node` + from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node + + class RayWorkerWrapper(WorkerWrapperBase): + """Ray wrapper for vllm.worker.Worker, allowing Worker to be + lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Since the compiled DAG runs a main execution + # in a different thread that calls cuda.set_device. + # The flag indicates is set_device is called on + # that thread. + self.compiled_dag_cuda_device_set = False + + self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + dec_hook=decode_hook) + self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + + def get_node_ip(self) -> str: + return get_ip() + + def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + node_id = ray.get_runtime_context().get_node_id() + device_key = vllm.platforms.current_platform.ray_device_key + if not device_key: + raise RuntimeError("current platform %s does not support ray.", + vllm.platforms.current_platform.device_name) + gpu_ids = ray.get_runtime_context().get_accelerator_ids( + )[device_key] + return node_id, gpu_ids + + def execute_model_spmd( + self, req_or_tuple: Union[bytes, + Tuple[bytes, + Optional[IntermediateTensors]]] + ) -> bytes: + """Execute model in SPMD fashion: used only when SPMD worker and + compiled DAG are both enabled. + + Args: + req_or_tuple: A request or a tuple containing the + request and intermediate tensors. Intermediate tensors are + None unless if it is provided because it is > 0 pipeline + stage. The request is serialized by msgspec. + """ + if isinstance(req_or_tuple, bytes): + serialized_req, intermediate_tensors = req_or_tuple, None + else: + serialized_req, intermediate_tensors = req_or_tuple + + execute_model_req = self.input_decoder.decode(serialized_req) + + # TODO(swang): This is needed right now because Ray Compiled Graph + # executes on a background thread, so we need to reset torch's + # current device. 
+ import torch + if not self.compiled_dag_cuda_device_set: + torch.cuda.set_device(self.worker.device) + self.compiled_dag_cuda_device_set = True + + output = self.worker._execute_model_spmd(execute_model_req, + intermediate_tensors) + # Pipeline model request and output to the next pipeline stage. + if isinstance(output, IntermediateTensors): + output = serialized_req, output + else: + output = self.output_encoder.encode(output) + + return output + + def setup_device_if_necessary(self): + # TODO(swang): This is needed right now because Ray CG executes + # on a background thread, so we need to reset torch's current + # device. + # We can remove this API after it is fixed in compiled graph. + assert self.worker is not None, "Worker is not initialized" + if not self.compiled_dag_cuda_device_set: + if current_platform.is_tpu(): + # Not needed + pass + else: + import torch + torch.cuda.set_device(self.worker.device) + + self.compiled_dag_cuda_device_set = True + + def execute_model_ray( + self, + scheduler_output: Union["SchedulerOutput", + Tuple["SchedulerOutput", + "IntermediateTensors"]], + ) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput", + "IntermediateTensors"]]: + # This method is used by Ray Compiled Graph to execute the model, + # and it needs a special logic of self.setup_device_if_necessary() + self.setup_device_if_necessary() + assert self.worker is not None, "Worker is not initialized" + if isinstance(scheduler_output, tuple): + scheduler_output, intermediate_tensors = scheduler_output + else: + scheduler_output, intermediate_tensors = scheduler_output, None + output = self.worker.model_runner.execute_model( + scheduler_output, intermediate_tensors) + if isinstance(output, IntermediateTensors): + output = scheduler_output, output + return output + + def override_env_vars(self, vars: Dict[str, str]): + os.environ.update(vars) + + ray_import_err = None + +except ImportError as e: + ray = None # type: ignore + ray_import_err = e + RayWorkerWrapper = None # type: ignore + + +def ray_is_available() -> bool: + """Returns True if Ray is available.""" + return ray is not None + + +def assert_ray_available(): + """Raise an exception if Ray is not available.""" + if ray is None: + raise ValueError("Failed to import Ray, please install Ray with " + "`pip install ray`.") from ray_import_err + + +def _verify_bundles(placement_group: "PlacementGroup", + parallel_config: ParallelConfig, device_str: str): + """Verify a given placement group has bundles located in the right place. + + There are 2 rules. + - Warn if all tensor parallel workers cannot fit in a single node. + - Fail if driver node is not included in a placement group. + """ + assert ray.is_initialized(), ( + "Ray is not initialized although distributed-executor-backend is ray.") + pg_data = placement_group_table(placement_group) + # bundle_idx -> node_id + bundle_to_node_ids = pg_data["bundles_to_node_id"] + # bundle_idx -> bundle (e.g., {"GPU": 1}) + bundles = pg_data["bundles"] + # node_id -> List of bundle (e.g., {"GPU": 1}) + node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list) + + for bundle_idx, node_id in bundle_to_node_ids.items(): + node_id_to_bundle[node_id].append(bundles[bundle_idx]) + driver_node_id = ray.get_runtime_context().get_node_id() + + if driver_node_id not in node_id_to_bundle: + raise RuntimeError( + f"driver node id {driver_node_id} is not included in a placement " + f"group {placement_group.id}. Node id -> bundles " + f"{node_id_to_bundle}. 
" + "You don't have enough GPUs available in a current node. Check " + "`ray status` and `ray list nodes` to see if you have available " + "GPUs in a node `{driver_node_id}` before starting an vLLM engine." + ) + + for node_id, bundles in node_id_to_bundle.items(): + if len(bundles) < parallel_config.tensor_parallel_size: + logger.warning( + "tensor_parallel_size=%d " + "is bigger than a reserved number of %ss (%d " + "%ss) in a node %s. Tensor parallel workers can be " + "spread out to 2+ nodes which can degrade the performance " + "unless you have fast interconnect across nodes, like " + "Infiniband. To resolve this issue, make sure you have more " + "than %d GPUs available at each node.", + parallel_config.tensor_parallel_size, device_str, len(bundles), + device_str, node_id, parallel_config.tensor_parallel_size) + + +def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): + """Wait until a placement group is ready. + + It prints the informative log messages if the placement group is + not created within time. + + """ + # Wait until PG is ready - this will block until all + # requested resources are available, and will timeout + # if they cannot be provisioned. + placement_group_specs = current_placement_group.bundle_specs + + s = time.time() + pg_ready_ref = current_placement_group.ready() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval) + if len(ready) > 0: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for creating a placement group of specs for " + "%d seconds. specs=%s. Check `ray status` and " + "`ray list nodes` to see if you have enough resources," + " and make sure the IP addresses used by ray cluster" + " are the same as VLLM_HOST_IP environment variable" + " specified in each node if you are running on a multi-node.", + int(time.time() - s), placement_group_specs) + + try: + ray.get(pg_ready_ref, timeout=0) + except ray.exceptions.GetTimeoutError: + raise ValueError( + "Cannot provide a placement group of " + f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " + "`ray status` and `ray list nodes` to make sure the cluster has " + "enough resources.") from None + + +def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): + ray.util.remove_placement_group(current_placement_group) + s = time.time() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + pg = ray.util.get_current_placement_group() + if pg is None: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for removing a placement group of specs for " + "%d seconds.", int(time.time() - s)) + time.sleep(wait_interval) + + +def initialize_ray_cluster( + parallel_config: ParallelConfig, + ray_address: Optional[str] = None, +): + """Initialize the distributed cluster with Ray. + + it will connect to the Ray cluster and create a placement group + for the workers, which includes the specification of the resources + for each distributed worker. + + Args: + parallel_config: The configurations for parallel execution. + ray_address: The address of the Ray cluster. If None, uses + the default Ray cluster address. + """ + assert_ray_available() + from vllm.platforms import current_platform + + if ray.is_initialized(): + logger.info("Ray is already initialized. 
Skipping Ray initialization.") + elif current_platform.is_rocm() or current_platform.is_xpu(): + # Try to connect existing ray instance and create a new one if not found + try: + ray.init("auto") + except ConnectionError: + logger.warning( + "No existing RAY instance detected. " + "A new instance will be launched with current node resources.") + ray.init(address=ray_address, num_gpus=parallel_config.world_size) + else: + import os + runtime_env = {} + import torch + device_count = torch.cuda.device_count() + import vllm.envs as envs + nccl_if_name = os.environ.get("NCCL_SOCKET_IFNAME",None) + vllm_nccl_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) + if nccl_if_name is not None and vllm_nccl_comm is not None: + runtime_env = {"env_vars":{ + "NCCL_SOCKET_IFNAME":nccl_if_name, + "VLLM_FORCE_NCCL_COMM":vllm_nccl_comm}} + elif nccl_if_name is not None: + runtime_env = {"env_vars":{ + "NCCL_SOCKET_IFNAME":nccl_if_name}} + elif vllm_nccl_comm is not None: + runtime_env = {"env_vars":{ + "VLLM_FORCE_NCCL_COMM":vllm_nccl_comm}} + if "env_vars" not in runtime_env: + runtime_env = { + "env_vars":{"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES":"1"} + } + else: + runtime_env["env_vars"].update({"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES":"1"}) + all_envs = dict(os.environ) + all_vllm_envs = {k: v for k,v in all_envs.items() if "VLLM" in k} + runtime_env["env_vars"].update(all_vllm_envs) + # ray.init(address=ray_address, ignore_reinit_error=True, runtime_env=runtime_env) + if device_count >= parallel_config.world_size: + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size, + runtime_env=runtime_env) + else: + ray.init(address=ray_address, ignore_reinit_error=True, runtime_env=runtime_env) + + device_str = current_platform.ray_device_key + if not device_str: + raise ValueError( + f"current platform {current_platform.device_name} does not " + "support ray.") + + # Create or get the placement group for worker processes + if parallel_config.placement_group: + current_placement_group = parallel_config.placement_group + else: + current_placement_group = ray.util.get_current_placement_group() + + if current_placement_group: + logger.info("Using the existing placement group") + + # We are in a placement group + bundles = current_placement_group.bundle_specs + # Verify that we can use the placement group. + device_bundles = 0 + for bundle in bundles: + bundle_devices = bundle.get(device_str, 0) + if bundle_devices > 1: + raise ValueError( + "Placement group bundle cannot have more than 1 " + f"{device_str}.") + if bundle_devices: + device_bundles += 1 + if parallel_config.world_size > device_bundles: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group. " + f"Required number of devices: {parallel_config.world_size}. " + f"Total number of devices: {device_bundles}.") + else: + logger.info("No current placement group found. " + "Creating a new placement group.") + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + # Log a warning message and delay resource allocation failure response. 
+ # Avoid immediate rejection to allow user-initiated placement group + # created and wait cluster to be ready + if parallel_config.world_size > num_devices_in_cluster: + logger.warning( + "The number of required %ss exceeds the total " + "number of available %ss in the placement group.", device_str, + device_str) + # Create a new placement group + placement_group_specs: List[Dict[str, float]] = ([{ + device_str: 1.0 + } for _ in range(parallel_config.world_size)]) + + # vLLM engine is also a worker to execute model with an accelerator, + # so it requires to have the device in a current node. Check if + # the current node has at least one device. + current_ip = get_ip() + current_node_id = ray.get_runtime_context().get_node_id() + current_node_resource = available_resources_per_node()[current_node_id] + if current_node_resource.get(device_str, 0) < 1: + raise ValueError( + f"Current node has no {device_str} available. " + f"{current_node_resource=}. vLLM engine cannot start without " + f"{device_str}. Make sure you have at least 1 {device_str} " + f"available in a node {current_node_id=} {current_ip=}.") + # This way, at least bundle is required to be created in a current + # node. + placement_group_specs[0][f"node:{current_ip}"] = 0.001 + + # By default, Ray packs resources as much as possible. + current_placement_group = ray.util.placement_group( + placement_group_specs, strategy="PACK") + _wait_until_pg_ready(current_placement_group) + + assert current_placement_group is not None + _verify_bundles(current_placement_group, parallel_config, device_str) + # Set the placement group in the parallel config + parallel_config.placement_group = current_placement_group + + +def get_num_tpu_nodes() -> int: + from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + assert total_tpus % tpus_per_node == 0 + return total_tpus // tpus_per_node + + +def get_num_nodes_in_placement_group() -> int: + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py new file mode 100644 index 0000000..8c004c7 --- /dev/null +++ b/vllm/executor/uniproc_executor.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +import vllm.envs as envs +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + run_method) +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class UniProcExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model. 
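+        The single worker runs in the engine process, so collective_rpc reduces to a direct call on self.driver_worker.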
+ """ + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rpc_rank=0) + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + local_rank = 0 + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split( + ":") + if len(device_info) > 1: + local_rank = int(device_info[1]) + rank = 0 + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + if kwargs is None: + kwargs = {} + answer = run_method(self.driver_worker, method, args, kwargs) + return [answer] + + def check_health(self) -> None: + # UniProcExecutor will always be healthy as long as + # it's running. + return + + +UniProcExecutorAsync = UniProcExecutor + + +class ExecutorWithExternalLauncher(UniProcExecutor): + """An executor that uses external launchers to launch engines, + specially designed for torchrun-compatible launchers, for + offline inference with tensor parallelism. + + see https://github.com/vllm-project/vllm/issues/11400 for + the motivation, and examples/offline_inference/torchrun_example.py + for the usage example. + + The key idea: although it is tensor-parallel inference, we only + create one worker per executor, users will launch multiple + engines with torchrun-compatible launchers, and all these engines + work together to process the same prompts. When scheduling is + deterministic, all the engines will generate the same outputs, + and they don't need to synchronize the states with each other. + """ + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model. + """ + assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \ + ("ExecutorWithExternalLauncher does not " + "support pipeline parallelism.") + assert self.vllm_config.scheduler_config.delay_factor == 0.0, \ + ("ExecutorWithExternalLauncher needs deterministic " + "execution, so it" + "does not support delay_factor in scheduling") + if envs.VLLM_USE_V1: + assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ + ("To get deterministic execution in V1, " + "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rpc_rank=0) + # engines are launched in torchrun-compatible launchers + # so we can use the env:// method. + # required env vars: + # - RANK + # - LOCAL_RANK + # - MASTER_ADDR + # - MASTER_PORT + distributed_init_method = "env://" + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + is_driver_worker = True + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """ + Determine the number of available KV blocks. + Add an additional all_reduce to get the min across all ranks. 
+ Note that even if we have the same `gpu_memory_utilization` and + `swap_space`, the available memory in every rank might still + differ because NCCL can take different amounts of memory in + different ranks. Therefore, it is necessary to test if all ranks + agree on the same KV cache configuration. + """ + a, b = super().determine_num_available_blocks() + from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group + a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64) + b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64) + dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + return a_tensor.item(), b_tensor.item() diff --git a/vllm/forward_context.py b/vllm/forward_context.py new file mode 100644 index 0000000..e195a03 --- /dev/null +++ b/vllm/forward_context.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch +import torch.distributed as dist + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + +logger = init_logger(__name__) + +track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0 +last_logging_time: float = 0 +forward_start_time: float = 0 +batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL +batchsize_forward_time: defaultdict = defaultdict(list) + + +@dataclass +class DPMetadata: + cu_tokens_across_dp_cpu: torch.Tensor + + +@dataclass +class ForwardContext: + # copy from vllm_config.compilation_config.static_forward_context + no_compile_layers: dict[str, Any] + # TODO: extend to support per-layer dynamic forward context + attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + # TODO: remove after making all virtual_engines share the same kv cache + virtual_engine: int # set dynamically for each forward pass + # set dynamically for each forward pass + dp_metadata: Optional[DPMetadata] = None + + +_forward_context: Optional[ForwardContext] = None + + +def get_forward_context() -> ForwardContext: + """Get the current forward context.""" + assert _forward_context is not None, ( + "Forward context is not set. " + "Please use `set_forward_context` to set the forward context.") + return _forward_context + + +@contextmanager +def set_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: int = 0): + """A context manager that stores the current forward context, + can be attention metadata, etc. + Here we can inject common logic for every model forward pass. 
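+    The previous context is saved and restored on exit, so nested calls are safe.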
+ """ + global forward_start_time + need_to_track_batchsize = track_batchsize and attn_metadata is not None + if need_to_track_batchsize: + forward_start_time = time.perf_counter() + dp_metadata: Optional[DPMetadata] = None + if vllm_config.parallel_config.data_parallel_size > 1: + dp_size = vllm_config.parallel_config.data_parallel_size + dp_rank = vllm_config.parallel_config.data_parallel_rank + if attn_metadata is not None: + if hasattr(attn_metadata, "num_prefill_tokens"): + # for v0 attention backends + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + # for v1 attention backends + batchsize = attn_metadata.num_input_tokens + else: + batchsize = num_tokens + num_tokens_across_dp = [0] * dp_size + num_tokens_across_dp[dp_rank] = batchsize + num_tokens_tensor = torch.tensor(num_tokens_across_dp, + device="cpu", + dtype=torch.int32) + from vllm.distributed.parallel_state import get_dp_group + dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) + dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) + + global _forward_context + prev_context = _forward_context + _forward_context = ForwardContext( + no_compile_layers=vllm_config.compilation_config. + static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata) + try: + yield + finally: + global last_logging_time, batchsize_logging_interval + if need_to_track_batchsize: + if hasattr(attn_metadata, "num_prefill_tokens"): + # for v0 attention backends + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + # for v1 attention backends + batchsize = attn_metadata.num_input_tokens + # we use synchronous scheduling right now, + # adding a sync point here should not affect + # scheduling of the next batch + torch.cuda.synchronize() + now = time.perf_counter() + # time measurement is in milliseconds + batchsize_forward_time[batchsize].append( + (now - forward_start_time) * 1000) + if now - last_logging_time > batchsize_logging_interval: + last_logging_time = now + forward_stats = [] + for bs, times in batchsize_forward_time.items(): + if len(times) <= 1: + # can be cudagraph / profiling run + continue + medium = torch.quantile(torch.tensor(times), q=0.5).item() + medium = round(medium, 2) + forward_stats.append((bs, len(times), medium)) + forward_stats.sort(key=lambda x: x[1], reverse=True) + if forward_stats: + logger.info(("Batchsize forward time stats " + "(batchsize, count, median_time(ms)): %s"), + forward_stats) + _forward_context = prev_context diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py new file mode 100644 index 0000000..6f8f2cd --- /dev/null +++ b/vllm/inputs/__init__.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonInputs, SingletonInputsAdapter, SingletonPrompt, + TextPrompt, TokenInputs, TokensPrompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + token_inputs, zip_enc_dec_prompts) +from .registry import (DummyData, InputContext, InputProcessingContext, + InputRegistry) + +INPUT_REGISTRY = InputRegistry() +""" +The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine` +to dispatch data processing according to the target model. 
+""" + +__all__ = [ + "TextPrompt", + "TokensPrompt", + "PromptType", + "SingletonPrompt", + "ExplicitEncoderDecoderPrompt", + "TokenInputs", + "token_inputs", + "DecoderOnlyInputs", + "EncoderDecoderInputs", + "ProcessorInputs", + "SingletonInputs", + "SingletonInputsAdapter", + "build_explicit_enc_dec_prompt", + "to_enc_dec_tuple_list", + "zip_enc_dec_prompts", + "INPUT_REGISTRY", + "DummyData", + "InputContext", + "InputProcessingContext", + "InputRegistry", +] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py new file mode 100644 index 0000000..138a8f6 --- /dev/null +++ b/vllm/inputs/data.py @@ -0,0 +1,405 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast + +import torch +from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never + +if TYPE_CHECKING: + from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict) + from vllm.multimodal.inputs import MultiModalInputs + + +class TextPrompt(TypedDict): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_token_ids: list[int] + """A list of token IDs to pass to the model.""" + + token_type_ids: NotRequired[list[int]] + """A list of token type IDs to pass to the cross encoder model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + +SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +""" +Set of possible schemas for a single prompt: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) + +Note that "singleton" is as opposed to a data structure +which encapsulates multiple prompts, i.e. of the sort +which may be utilized for encoder/decoder models when +the user desires to express both the encoder & decoder +prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` + +A prompt of type :class:`SingletonPrompt` may be employed +as (1) input to a decoder-only model, (2) input to +the encoder of an encoder/decoder model, in the scenario +where the decoder-prompt is not specified explicitly, or +(3) as a member of a larger data structure encapsulating +more than one prompt, i.e. 
:class:`ExplicitEncoderDecoderPrompt` +""" + +_T1_co = TypeVar("_T1_co", + bound=SingletonPrompt, + default=SingletonPrompt, + covariant=True) +_T2_co = TypeVar("_T2_co", + bound=SingletonPrompt, + default=SingletonPrompt, + covariant=True) + + +# TODO: Make fields ReadOnly once mypy supports it +class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. + + The encoder and decoder prompts, respectively, may be formatted + according to any of the :class:`SingletonPrompt` schemas, + and are not required to have the same schema. + + Only the encoder prompt may have multi-modal data. mm_processor_kwargs + should be at the top-level, and should not be set in the encoder/decoder + prompts, since they are agnostic to the encoder/decoder. + + Note that an :class:`ExplicitEncoderDecoderPrompt` may not + be used as an input to a decoder-only model, + and that the :code:`encoder_prompt` and :code:`decoder_prompt` + fields of this data structure themselves must be + :class:`SingletonPrompt` instances. + """ + + encoder_prompt: _T1_co + + decoder_prompt: Optional[_T2_co] + + mm_processor_kwargs: NotRequired[dict[str, Any]] + + +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +""" +Set of possible schemas for an LLM input, including +both decoder-only and encoder/decoder input types: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +- A single data structure containing both an encoder and a decoder prompt + (:class:`ExplicitEncoderDecoderPrompt`) +""" + + +class TokenInputs(TypedDict): + """Represents token-based inputs.""" + + type: Literal["token"] + """The type of inputs.""" + + prompt_token_ids: list[int] + """The token IDs of the prompt.""" + + token_type_ids: NotRequired[list[int]] + """The token type IDs of the prompt.""" + + prompt: NotRequired[str] + """ + The original prompt text corresponding to the token IDs, if available. + """ + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + multi_modal_inputs: NotRequired["MultiModalKwargs"] + """ + Optional multi-modal inputs to pass to the model, + if the model supports it. + """ + + multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] + """ + Placeholder ranges for the multi-modal data. + """ + + multi_modal_hashes: NotRequired[list[str]] + """ + The hashes of the multi-modal data. + """ + + mm_processor_kwargs: NotRequired[dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. 
+ """ + + +def token_inputs( + prompt_token_ids: list[int], + token_type_ids: Optional[list[int]] = None, + prompt: Optional[str] = None, + multi_modal_data: Optional["MultiModalDataDict"] = None, + multi_modal_inputs: Optional["MultiModalKwargs"] = None, + multi_modal_hashes: Optional[list[str]] = None, + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> TokenInputs: + """Construct :class:`TokenInputs` from optional values.""" + inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if token_type_ids is not None: + inputs["token_type_ids"] = token_type_ids + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + if multi_modal_inputs is not None: + inputs["multi_modal_inputs"] = multi_modal_inputs + if multi_modal_hashes is not None: + inputs["multi_modal_hashes"] = multi_modal_hashes + if multi_modal_placeholders is not None: + inputs["multi_modal_placeholders"] = multi_modal_placeholders + if mm_processor_kwargs is not None: + inputs["mm_processor_kwargs"] = mm_processor_kwargs + + return inputs + + +DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"] +""" +The inputs in :class:`~vllm.LLMEngine` before they are +passed to the model executor. +This specifies the data required for decoder-only models. +""" + + +class EncoderDecoderInputs(TypedDict): + """ + The inputs in :class:`~vllm.LLMEngine` before they are + passed to the model executor. + + This specifies the required data for encoder-decoder models. + """ + encoder: Union[TokenInputs, "MultiModalInputs"] + """The inputs for the encoder portion.""" + + decoder: Union[TokenInputs, "MultiModalInputs"] + """The inputs for the decoder portion.""" + + +SingletonInputs = Union[TokenInputs, "MultiModalInputs"] +""" +A processed :class:`SingletonPrompt` which can be passed to +:class:`vllm.sequence.Sequence`. +""" + + +@dataclass +class SingletonInputsAdapter: + """ + Unified interface to access the components of :class:`SingletonInputs`. 
+ """ + inputs: SingletonInputs + + @cached_property + def prompt(self) -> Optional[str]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("prompt") + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def prompt_token_ids(self) -> list[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("prompt_token_ids", []) + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def token_type_ids(self) -> list[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("token_type_ids", []) + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def prompt_embeds(self) -> Optional[torch.Tensor]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return None + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def multi_modal_data(self) -> "MultiModalDataDict": + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_data", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_kwargs", {}) + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_inputs", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_kwargs", {}) + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def multi_modal_hashes(self) -> list[str]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_hashes", []) + + if inputs["type"] == "multimodal": + # only the case when we use MultiModalInputs + return inputs.get("mm_hashes", []) # type: ignore[return-value] + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_placeholders", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_placeholders", {}) + + assert_never(inputs) # type: ignore[arg-type] + + @cached_property + def mm_processor_kwargs(self) -> dict[str, Any]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("mm_processor_kwargs", {}) + + if inputs["type"] == "multimodal": + return {} + + assert_never(inputs) # type: ignore[arg-type] + + +ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] +""" +The inputs to :data:`vllm.inputs.InputProcessor`. 
+""" + +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) + + +def build_explicit_enc_dec_prompt( + encoder_prompt: _T1, + decoder_prompt: Optional[_T2], + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return ExplicitEncoderDecoderPrompt( + encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt, + mm_processor_kwargs=mm_processor_kwargs) + + +def zip_enc_dec_prompts( + enc_prompts: Iterable[_T1], + dec_prompts: Iterable[Optional[_T2]], + mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]], + dict[str, Any]]] = None, +) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: + """ + Zip encoder and decoder prompts together into a list of + :class:`ExplicitEncoderDecoderPrompt` instances. + + ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same + dictionary will be used for every encoder/decoder prompt. If an iterable is + provided, it will be zipped with the encoder/decoder prompts. + """ + if mm_processor_kwargs is None: + mm_processor_kwargs = cast(dict[str, Any], {}) + if isinstance(mm_processor_kwargs, dict): + return [ + build_explicit_enc_dec_prompt( + encoder_prompt, decoder_prompt, + cast(dict[str, Any], mm_processor_kwargs)) + for (encoder_prompt, + decoder_prompt) in zip(enc_prompts, dec_prompts) + ] + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_proc_kwargs) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs + ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) + ] + + +def to_enc_dec_tuple_list( + enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], +) -> list[tuple[_T1, Optional[_T2]]]: + return [(enc_dec_prompt["encoder_prompt"], + enc_dec_prompt["decoder_prompt"]) + for enc_dec_prompt in enc_dec_prompts] diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py new file mode 100644 index 0000000..28e207d --- /dev/null +++ b/vllm/inputs/parse.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence +from typing import Literal, Optional, TypedDict, Union, cast, overload + +from typing_extensions import TypeIs + +from vllm.utils import is_list_of + +from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt) + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: list[int] + is_tokens: Literal[True] + + +@overload +def parse_and_batch_prompt( + prompt: Union[str, list[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[list[int], list[list[int]]]) -> Sequence[ParsedTokens]: + ... 
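+# The overloads above let type checkers narrow the result: string inputs yield ParsedText, token inputs yield ParsedTokens.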
+ + +def parse_and_batch_prompt( + prompt: Union[str, list[str], list[int], list[list[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt, str): + # case 2: array of strings + prompt = cast(list[str], prompt) + return [ + ParsedText(content=elem, is_tokens=False) for elem in prompt + ] + if is_list_of(prompt, int): + # case 3: array of tokens + prompt = cast(list[int], prompt) + return [ParsedTokens(content=prompt, is_tokens=True)] + if is_list_of(prompt, list): + prompt = cast(list[list[int]], prompt) + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt[0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in prompt + ] + + raise TypeError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class ParsedStrPrompt(TypedDict): + type: Literal["str"] + content: str + + +class ParsedTextPrompt(TypedDict): + type: Literal["text"] + content: TextPrompt + + +class ParsedTokensPrompt(TypedDict): + type: Literal["tokens"] + content: TokensPrompt + + +def parse_singleton_prompt( + prompt: SingletonPrompt, +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if "prompt_token_ids" in prompt: + return ParsedTokensPrompt(type="tokens", + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) + + raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") + + +def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: + return isinstance(prompt, dict) and "prompt_token_ids" in prompt + + +def is_explicit_encoder_decoder_prompt( + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt + + +def split_enc_dec_inputs( + inputs: ProcessorInputs, +) -> tuple[Optional[SingletonInputs], SingletonInputs]: + if "encoder" in inputs and "decoder" in inputs: + # NOTE: This passes pyright but not mypy + return ( + inputs["encoder"], # type: ignore[typeddict-item] + inputs["decoder"], # type: ignore[typeddict-item] + ) + + return None, inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py new file mode 100644 index 0000000..669fb96 --- /dev/null +++ b/vllm/inputs/preprocess.py @@ -0,0 +1,783 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from collections.abc import Mapping +from typing import Optional, Union, cast + +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, + MultiModalInputs) +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, + PromptType, SingletonInputs, SingletonPrompt, token_inputs) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt + +logger = 
init_logger(__name__) + + +class InputPreprocessor: + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[BaseTokenizerGroup], + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ) -> None: + super().__init__() + + self.model_config = model_config + self.tokenizer = tokenizer + self.mm_registry = mm_registry + + def get_tokenizer_group(self) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError("You cannot pass text prompts when " + "`skip_tokenizer_init` is True") + + return self.tokenizer + + def get_bos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for BOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + + def get_eos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def get_decoder_start_token_id(self) -> Optional[int]: + ''' + Obtain the decoder start token id employed by an encoder/decoder + model. Returns None for non-encoder/decoder models or if the + model config is unavailable. + ''' + + if not self.model_config.is_encoder_decoder: + logger.warning_once( + "Using None for decoder start token id because " + "this is not an encoder/decoder model.") + return None + + if (self.model_config is None or self.model_config.hf_config is None): + logger.warning_once( + "Using None for decoder start token id because " + "model config is not available.") + return None + + dec_start_token_id = getattr(self.model_config.hf_config, + 'decoder_start_token_id', None) + if dec_start_token_id is None: + logger.warning_once( + "Falling back on for decoder start token " + "id because decoder start token id is not " + "available.") + dec_start_token_id = self.get_bos_token_id() + + return dec_start_token_id + + def _get_default_enc_dec_decoder_prompt(self) -> list[int]: + ''' + Specifically for encoder/decoder models: + generate a default decoder prompt for when + the user specifies only the encoder prompt. + + Encoder/decoder models utilize the decoder + prompt in different ways; as new models are + added, it is intended that this function + will be extended to produce differing + default decoder prompts, depending on the + model variety. + + Absent a special case, the default behavior + of this method is to mirror the behavior of + the HuggingFace (HF) GenerationMixin for a None + decoder prompt, which is to employ a logit processor + setting to force the first decoded token to be . + Here, this behavior is approximated by having the + "default" decoder prompt be . + + However, it is possible that in the future + other models may have different or more + complex logic for the default decoder prompt. + This motivates having a special helper method + for default decoder prompts. + + Returns: + + * prompt_token_ids + ''' + + bos_token_id = self.get_bos_token_id() + assert bos_token_id is not None + return [bos_token_id] + + def _prepare_decoder_input_ids_for_generation( + self, + decoder_input_ids: Optional[list[int]], + ) -> list[int]: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models. 
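+        If no decoder prompt is given, a default one is built from the decoder start token (falling back to BOS); otherwise the start token is prepended when missing.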
+ + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + + Arguments: + + * decoder_input_ids: input token ids to preprocess + + Returns: + + * Processed token list + """ + + decoder_start_token_id = self.get_decoder_start_token_id() + assert decoder_start_token_id is not None + + if decoder_input_ids is None: + # no decoder prompt input -> + # use decoder_start_token_id as decoder_input_ids + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + def _apply_prompt_adapter( + self, + prompt_token_ids: list[int], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> list[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return prompt_token_ids + + def _tokenize_prompt( + self, + prompt: str, + lora_request: Optional[LoRARequest], + ) -> list[int]: + """ + Apply the model's tokenizer to a text prompt, returning the + corresponding token IDs. + """ + tokenizer = self.get_tokenizer_group() + add_special_tokens = None + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + add_special_tokens = False + + if (self.model_config.encoder_config is not None + and self.model_config.encoder_config.get( + "do_lower_case", False)): + prompt = prompt.lower() + + return tokenizer.encode(prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) + + async def _tokenize_prompt_async( + self, + prompt: str, + lora_request: Optional[LoRARequest], + ) -> list[int]: + """Async version of :meth:`_tokenize_prompt`.""" + tokenizer = self.get_tokenizer_group() + add_special_tokens = None + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + add_special_tokens = False + return await tokenizer.encode_async( + prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) + + def _can_process_multimodal(self) -> bool: + model_config = self.model_config + + if not model_config.is_multimodal_model: + raise ValueError("Your model does not support multi-modal inputs") + + # Interim measure so we can handle models that have yet to be + # updated to use the new multi-modal processor + can_process_multimodal = self.mm_registry.has_processor(model_config) + if not can_process_multimodal: + from vllm.model_executor.models.registry import _VLLM_MODELS + if not any(arch in _VLLM_MODELS + for arch in model_config.architectures): + logger.warning_once( + "Your model uses the legacy input pipeline, which will be " + "removed in an upcoming release. 
" + "Please upgrade to the new multi-modal processing pipeline " + "(https://docs.vllm.ai/en/latest/design/mm_processing.html)" + ) + + return can_process_multimodal + + def _process_multimodal( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Apply the model's multi-modal processor to a multi-modal prompt, + returning the corresponding token IDs and metadata. + """ + # At the moment on model (PrithviGeoSpatialMAE) requires to be + # initialized without a tokenizer while using also multi-modal + # input. + if not self.tokenizer: + tokenizer = object() # Dummy + else: + tokenizer_group = self.get_tokenizer_group() + tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + + mm_processor = self.mm_registry.create_processor(self.model_config, + tokenizer=tokenizer) + + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, + return_mm_hashes) + + async def _process_multimodal_async( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """Async version of :meth:`_process_multimodal`.""" + # At the moment on model (PrithviGeoSpatialMAE) requires to be + # initialized without a tokenizer while using also multi-modal + # input. + if not self.tokenizer: + tokenizer = object() # Dummy + else: + tokenizer_group = self.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async( + lora_request) + + mm_processor = self.mm_registry.create_processor(self.model_config, + tokenizer=tokenizer) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, + return_mm_hashes) + + def _prompt_to_llm_inputs( + self, + prompt: SingletonPrompt, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> SingletonInputs: + """ + Extract the singleton inputs from a prompt. 
+ + Arguments: + + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + * return_mm_hashes: whether to return multimodal hashes + + Returns: + + * :class:`SingletonInputs` instance + """ + parsed = parse_singleton_prompt(prompt) + + if parsed["type"] == "str": + prompt_text = parsed["content"] + prompt_token_ids = self._tokenize_prompt( + prompt_text, + lora_request=lora_request, + ) + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": + tokens_content = parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + return token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": + text_content = parsed["content"] + + prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + prompt_token_ids = self._tokenize_prompt( + prompt_text, + lora_request=lora_request, + ) + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + assert_never(parsed) + + async def _prompt_to_llm_inputs_async( + self, + prompt: SingletonPrompt, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> SingletonInputs: + """Async version of :meth:`_extract_prompt_components`.""" + parsed = parse_singleton_prompt(prompt) + + if parsed["type"] == "str": + prompt_text = parsed["content"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt_text, + lora_request=lora_request, + ) + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": + tokens_content = parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + return token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": + text_content = parsed["content"] + + prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return await 
self._process_multimodal_async( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + prompt_token_ids = await self._tokenize_prompt_async( + prompt_text, + lora_request=lora_request, + ) + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + assert_never(parsed) + + def _build_enc_dec_llm_inputs( + self, + encoder_inputs: SingletonInputs, + decoder_inputs: Optional[SingletonInputs], + ) -> EncoderDecoderInputs: + if (encoder_inputs["type"] == "token" + or encoder_inputs["type"] == "multimodal"): + pass + else: + assert_never(encoder_inputs) # type: ignore[arg-type] + + if decoder_inputs is None: + if self.model_config.hf_config.model_type == "whisper": + # For Whisper models, the text prompt should go to the decoder. + # If no explicit encoder/decoder inputs, then copy the prompt + # from the encoder to the decoder. The encoder tokens are later + # overridden by the audio features. + dec_token_ids = encoder_inputs["prompt_token_ids"].copy() + else: + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None) + decoder_inputs = token_inputs(dec_token_ids) + elif (decoder_inputs["type"] == "token" + or decoder_inputs["type"] == "multimodal"): + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] = dec_token_ids + + if "multi_modal_data" in decoder_inputs: + raise ValueError("Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet") + else: + assert_never(encoder_inputs) # type: ignore[arg-type] + + return EncoderDecoderInputs( + encoder=encoder_inputs, + decoder=decoder_inputs, + ) + + def _separate_enc_dec_inputs_from_mm_processor_outputs( + self, + inputs: SingletonInputs, + decoder_inputs_to_override: Optional[SingletonInputs] = None, + ) -> tuple[SingletonInputs, SingletonInputs]: + """ + For encoder/decoder models only: + Separate Encoder/Decoder inputs from a MultiModalEncDecInputs + """ + encoder_inputs: SingletonInputs + decoder_inputs: SingletonInputs + if inputs["type"] == "multimodal": + # Multimodal data inputs + assert ("encoder_prompt" in inputs + and "encoder_prompt_token_ids" in inputs) + inputs = cast(MultiModalEncDecInputs, inputs) + encoder_inputs = token_inputs( + prompt=inputs["encoder_prompt"], + prompt_token_ids=inputs["encoder_prompt_token_ids"], + ) + if decoder_inputs_to_override is not None: + decoder_inputs = MultiModalInputs( + type="multimodal", + prompt=decoder_inputs_to_override.get("prompt", ""), + prompt_token_ids=decoder_inputs_to_override[ + "prompt_token_ids"], + mm_kwargs=inputs["mm_kwargs"], + mm_hashes=inputs["mm_hashes"], + mm_placeholders=inputs["mm_placeholders"], + ) + else: + decoder_inputs = MultiModalInputs( + type="multimodal", + prompt=inputs["prompt"], + prompt_token_ids=inputs["prompt_token_ids"], + mm_kwargs=inputs["mm_kwargs"], + mm_hashes=inputs["mm_hashes"], + mm_placeholders=inputs["mm_placeholders"], + ) + elif inputs["type"] == "token": + # Text-only inputs + encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) + decoder_inputs = decoder_inputs_to_override or inputs + else: + assert_never(inputs) # type: ignore[arg-type] + return encoder_inputs, decoder_inputs + + def _process_encoder_decoder_prompt( + self, + prompt: PromptType, + ) -> EncoderDecoderInputs: + """ + For encoder/decoder models only: + Process an 
input prompt into an :class:`EncoderDecoderInputs` instance. + + There are two types of input prompts: + singleton prompts which carry only the + encoder prompt, and explicit encoder/decoder + prompts which carry both the encoder and the + decoder prompts as member variables. + + This function handles the following scenarios: + * Singleton encoder prompt: extract encoder prompt + token ids & infer default decoder prompt token ids + * Explicit encoder/decoder prompt: extract encoder + and decoder prompt token ids + + Note that for Explicit encoder/decoder prompts, + each sub-prompt (encoder or decoder prompt) can + have any possible singleton type; thus this + method relies on helper functions to obtain + token ids for the sub-prompts. + + Arguments: + + * prompt: an input prompt + + Returns: + + * :class:`EncoderDecoderInputs` instance + """ + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] + + if is_explicit_encoder_decoder_prompt(prompt): + encoder_inputs = self._prompt_to_llm_inputs( + prompt["encoder_prompt"]) + if (decoder_input := prompt["decoder_prompt"]) is None: + decoder_inputs = None + else: + decoder_inputs = self._prompt_to_llm_inputs(decoder_input) + # For multimodal model, override decoder prompt from processor + # with explicit decoder prompt. + if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + encoder_inputs, decoder_inputs)) + else: + inputs = self._prompt_to_llm_inputs(prompt) + if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + # Encoder-Decoder Multimodal model + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + inputs)) + else: + encoder_inputs = inputs + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) + + async def _process_encoder_decoder_prompt_async( + self, + prompt: PromptType, + ) -> EncoderDecoderInputs: + """Async version of :meth:`_process_encoder_decoder_prompt`.""" + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] + + if is_explicit_encoder_decoder_prompt(prompt): + encoder_task = self._prompt_to_llm_inputs_async( + prompt["encoder_prompt"]) + + if (decoder_input := prompt["decoder_prompt"]) is None: + encoder_inputs = await encoder_task + decoder_inputs = None + else: + decoder_task = self._prompt_to_llm_inputs_async(decoder_input) + + encoder_inputs, decoder_inputs = await asyncio.gather( + encoder_task, decoder_task) + + # For multimodal model, override decoder prompt from processor + # with explicit decoder prompt. 
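+ # The explicit decoder prompt (when given) is forwarded as + # decoder_inputs_to_override, so its prompt and token ids replace the + # processor-derived decoder text while the processor's mm_kwargs, + # mm_hashes and mm_placeholders are kept.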
+ if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + encoder_inputs, decoder_inputs)) + else: + inputs = await self._prompt_to_llm_inputs_async(prompt) + if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + # Encoder-Decoder Multimodal model + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + inputs)) + else: + encoder_inputs = inputs + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) + + def _build_decoder_only_llm_inputs( + self, + prompt_inputs: DecoderOnlyInputs, + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> DecoderOnlyInputs: + if (prompt_inputs["type"] == "token" + or prompt_inputs["type"] == "multimodal"): + prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( + prompt_inputs["prompt_token_ids"], + prompt_adapter_request=prompt_adapter_request, + ) + else: + assert_never(prompt_inputs) # type: ignore[arg-type] + + return prompt_inputs + + def _process_decoder_only_prompt( + self, + prompt: SingletonPrompt, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> DecoderOnlyInputs: + """ + For decoder-only models: + Process an input prompt into an :class:`DecoderOnlyInputs` instance. + + Arguments: + + * prompt: input prompt + * lora_request + * prompt_adapter_request + * return_mm_hashes + + Returns: + + * :class:`DecoderOnlyInputs` instance + """ + + prompt_comps = self._prompt_to_llm_inputs( + prompt, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _process_decoder_only_prompt_async( + self, + prompt: SingletonPrompt, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> DecoderOnlyInputs: + """Async version of :meth:`_process_decoder_only_prompt`.""" + prompt_comps = await self._prompt_to_llm_inputs_async( + prompt, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + def preprocess( + self, + prompt: PromptType, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> ProcessorInputs: + """Preprocess the input prompt.""" + if self.model_config.is_encoder_decoder: + assert not return_mm_hashes, ( + "Multimodal hashes for encoder-decoder models should not be ", + "returned until they are supported on vLLM V1.") + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return self._process_encoder_decoder_prompt(prompt) + + if is_explicit_encoder_decoder_prompt(prompt): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return self._process_decoder_only_prompt( + prompt, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + return_mm_hashes=return_mm_hashes, + ) + + async def preprocess_async( + self, + prompt: PromptType, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: 
Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> ProcessorInputs: + """Async version of :meth:`preprocess`.""" + if self.model_config.is_encoder_decoder: + assert not return_mm_hashes, ( + "Multimodal hashes for encoder-decoder models should not be ", + "returned until they are supported on vLLM V1.") + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return await self._process_encoder_decoder_prompt_async(prompt) + + if is_explicit_encoder_decoder_prompt(prompt): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return await self._process_decoder_only_prompt_async( + prompt, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + return_mm_hashes=return_mm_hashes, + ) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py new file mode 100644 index 0000000..0579893 --- /dev/null +++ b/vllm/inputs/registry.py @@ -0,0 +1,487 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +from collections import UserDict +from collections.abc import Mapping +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional, + Protocol, Union) + +from torch import nn +from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from typing_extensions import TypeVar, assert_never + +from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs) + +from .data import ProcessorInputs, SingletonInputs +from .parse import split_enc_dec_inputs + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, + MultiModalRegistry) + from vllm.sequence import SequenceData + +logger = init_logger(__name__) + +_T = TypeVar("_T") +_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) +_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) + + +@dataclass(frozen=True) +class InputContext: + """ + Contains information about the model which may be used to + modify the inputs. + """ + + model_config: "ModelConfig" + """The configuration of the model.""" + + def get_hf_config( + self, + typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig, + /, + ) -> _C: + """ + Get the HuggingFace configuration + (:class:`transformers.PretrainedConfig`) of the model, + additionally checking its type. + + Raises: + TypeError: If the configuration is not of the specified type. + """ + hf_config = self.model_config.hf_config + if not isinstance(hf_config, typ): + raise TypeError("Invalid type of HuggingFace config. " + f"Expected type: {typ}, but " + f"found type: {type(hf_config)}") + + return hf_config + + def get_hf_image_processor_config(self) -> dict[str, Any]: + """ + Get the HuggingFace image processor configuration of the model. + """ + return self.model_config.hf_image_processor_config + + def get_mm_config(self): + """ + Get the multimodal config of the model. + + Raises: + RuntimeError: If the model is not a multimodal model. 
+ """ + mm_config = self.model_config.multimodal_config + if mm_config is None: + raise RuntimeError("Not a multimodal model") + + return mm_config + + def get_hf_processor( + self, + typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + /, + **kwargs: object, + ) -> _P: + """ + Get the HuggingFace processor + (:class:`transformers.ProcessorMixin`) of the model, + additionally checking its type. + + Raises: + TypeError: If the processor is not of the specified type. + """ + return cached_processor_from_config( + self.model_config, + processor_cls=typ, + **kwargs, + ) + + def init_processor( + self, + typ: type[_T], + /, + **kwargs: object, + ) -> _T: + """ + Initialize a HuggingFace-like processor class, merging the + keyword arguments with those in the model's configuration. + """ + base_kwargs = self.model_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + + return typ(**merged_kwargs) + + +@dataclass(frozen=True) +class InputProcessingContext(InputContext): + tokenizer: AnyTokenizer + """The tokenizer used to tokenize the inputs.""" + + def get_hf_processor( + self, + typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + /, + **kwargs: object, + ) -> _P: + return super().get_hf_processor( + typ, + tokenizer=self.tokenizer, + **kwargs, + ) + + def call_hf_processor( + self, + hf_processor: ProcessorMixin, + data: Mapping[str, object], + kwargs: Mapping[str, object] = {}, + ) -> BatchFeature: + """ + Call :code:`hf_processor` on the prompt :code:`data` + (text, image, audio...) with configurable options :code:`kwargs`. + """ + assert callable(hf_processor) + + base_kwargs = self.model_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = resolve_mm_processor_kwargs( + base_kwargs, + kwargs, + hf_processor, + requires_kw_only=False, + allow_var_kwargs=True, + ) + + try: + return hf_processor(**data, **merged_kwargs, return_tensors="pt") + except Exception as exc: + msg = (f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={merged_kwargs}") + + raise RuntimeError(msg) from exc + + +N = TypeVar("N", bound=type[nn.Module]) + + +class DummyData(NamedTuple): + """Dummy data used for profiling.""" + + seq_data: "SequenceData" + multi_modal_data: Optional["MultiModalDataDict"] = None + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None + + +class DummyDataFactory(Protocol): + + def __call__( + self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + **mm_processor_kwargs: Any, + ) -> DummyData: + """ + Create dummy data to be inputted into the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + + The :code:`mm_processor_kwargs` are overrides provided at + initialization time to values in the config whose values + may affect the number of tokens per instance. + """ + ... + + +class _MultiModalCounts(UserDict[str, int]): + """ + Wraps `mm_counts` for a more informative error message + when attempting to access a plugin that does not exist. + """ + + def __getitem__(self, key: str) -> int: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"There is no multi-modal plugin with the key: {key}. 
" + f"Available keys: {set(self.keys())}") + raise KeyError(msg) from exc + + +InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] +"""Preprocess the inputs to the model.""" + + +class InputRegistry: + """ + A registry to dispatch data processing + according to the target model. + """ + + def __init__(self) -> None: + self._dummy_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._dummy_encoder_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._input_processors_by_model_type = \ + ClassRegistry[nn.Module, InputProcessor]() + + def _default_dummy_data_factory( + self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> DummyData: + """ + The default dummy data factory represents the longest possible text + that can be inputted to the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + """ + # Avoid circular import + from vllm.sequence import SequenceData + + return DummyData(SequenceData.from_prompt_token_counts((0, seq_len))) + + def register_dummy_data(self, factory: DummyDataFactory): + """ + Register a dummy data factory to a model class. + + During memory profiling, the provided function is invoked to create + dummy data to be inputted into the model. The resulting memory usage + should be an upper bound of what the model would use at inference time. + """ + + def wrapper(model_cls: N) -> N: + if self._dummy_factories_by_model_type.contains(model_cls, + strict=True): + logger.warning( + "Model class %s already has dummy data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def _get_dummy_data_factory(self, model_cls: type[nn.Module]): + return self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + + def register_dummy_encoder_data(self, factory: DummyDataFactory): + """ + Register a dummy encoder data factory to a model class + + This is similar to :meth:`~register_dummy_data`, but for encoder input. + """ + + def wrapper(model_cls: N) -> N: + if self._dummy_encoder_factories_by_model_type.contains( + model_cls, strict=True): + logger.warning( + "Model class %s already has dummy encoder data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_encoder_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]): + return self._dummy_encoder_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + + def dummy_data_for_profiling( + self, + model_config: "ModelConfig", + seq_len: int, + mm_registry: "MultiModalRegistry", + is_encoder_data: bool = False, + ) -> DummyData: + """ + Create dummy data for profiling the memory usage of a model. + + The model is identified by ``model_config``. + + Note: + This should be called after + :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. 
+ """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.profiling import MultiModalProfiler + from vllm.sequence import SequenceData + + if mm_registry.has_processor(model_config): + processor = mm_registry.create_processor(model_config, + disable_cache=True) + profiler = MultiModalProfiler(processor) + + dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len) + if is_encoder_data else + profiler.get_decoder_dummy_data(seq_len)) + _seq_data = SequenceData.from_seqs( + dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined] + + dummy_data = DummyData( + seq_data=_seq_data, + multi_modal_data=getattr(dummy_data_v1, "multi_modal_data", + None), + multi_modal_placeholders=getattr(dummy_data_v1, + "multi_modal_placeholders", + None), + ) + else: + model_cls, _ = get_model_architecture(model_config) + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + dummy_factory, + overrides=model_config.mm_processor_kwargs, + requires_kw_only=False, + allow_var_kwargs=True, + ) + + dummy_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) + + # Having more tokens is over-conservative but otherwise fine + num_tokens = dummy_data.seq_data.prompt_token_ids + if len(num_tokens) < seq_len: + if is_encoder_data: + logger.warning_once( + f"Expected at least {seq_len} dummy encoder tokens for " + f"profiling, but found {len(num_tokens)} tokens instead.") + else: + raise AssertionError( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(num_tokens)} tokens instead.") + + if (dummy_data.multi_modal_data is not None and + not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): + for k, v in dummy_data.multi_modal_data.items(): + num_items = len(v) if isinstance(v, list) else 1 + num_expected = mm_counts[k] + assert num_items >= num_expected, ( + f"Expected at least {num_expected} dummy '{k}' instances " + f"for profiling, but found {num_items} instances instead.") + + return dummy_data + + def _default_input_processor( + self, + ctx: InputContext, + inputs: ProcessorInputs, + **kwargs: object, + ) -> ProcessorInputs: + """The default input processor is a no-op.""" + return inputs + + def register_input_processor(self, processor: InputProcessor): + """ + Register an input processor to a model class. + + The provided function is invoked on each input to the model. This + happens before + :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`. + """ + + def wrapper(model_cls: N) -> N: + if self._input_processors_by_model_type.contains(model_cls, + strict=True): + logger.warning( + "Model class %s already has input processor " + "registered to %s. 
It is overwritten by the new one.", + model_cls, self) + + self._input_processors_by_model_type[model_cls] = processor + + return model_cls + + return wrapper + + def _get_model_input_processor(self, model_cls: type[nn.Module]): + return self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + + def _ensure_mm_kwargs( + self, + inputs: SingletonInputs, + mm_processor_kwargs: dict[str, Any], + ): + if inputs["type"] == "token": + # In case the input processor for that model fails to set it + if "mm_processor_kwargs" not in inputs: + inputs["mm_processor_kwargs"] = mm_processor_kwargs + elif inputs["type"] == "multimodal": + # Be more strict in V2 + assert "mm_kwargs" in inputs + else: + assert_never(inputs["type"]) # type: ignore[arg-type] + + def process_input(self, model_config: "ModelConfig", + inputs: ProcessorInputs) -> ProcessorInputs: + """ + Apply an input processor to an instance of model inputs. + + The model is identified by ``model_config``. + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + processor = self._get_model_input_processor(model_cls) + + # Handle multimodal processor kwargs with priority: + # Inference kwargs -> Init kwargs -> {} + # If it's empty, it'll fall back to the default kwarg values + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + inputs.get("mm_processor_kwargs", {}), # type: ignore + processor, + requires_kw_only=False, + allow_var_kwargs=True, + ) + + processed_inputs = processor( + InputContext(model_config), + inputs, + **mm_processor_kwargs, + ) + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + if encoder_inputs is not None: + self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs) + if decoder_inputs is not None: + self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs) + + return processed_inputs + + def create_input_processor(self, model_config: "ModelConfig"): + """ + Create an input processor (see :meth:`_process_input`) for a + specific model. 
+ """ + return functools.partial(self.process_input, model_config) diff --git a/vllm/jsontree.py b/vllm/jsontree.py new file mode 100644 index 0000000..91cd7cb --- /dev/null +++ b/vllm/jsontree.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Helper functions to work with nested JSON structures.""" +from collections.abc import Iterable +from functools import reduce +from typing import Callable, TypeVar, Union, overload + +_T = TypeVar("_T") +_U = TypeVar("_U") + +JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"], + tuple["JSONTree[_T]", ...], _T] +"""A nested JSON structure where the leaves need not be JSON-serializable.""" + + +def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: + """Iterate through each leaf in a nested JSON structure.""" + if isinstance(value, dict): + for v in value.values(): + yield from json_iter_leaves(v) + elif isinstance(value, (list, tuple)): + for v in value: + yield from json_iter_leaves(v) + else: + yield value + + +def json_map_leaves( + func: Callable[[_T], _U], + value: JSONTree[_T], +) -> JSONTree[_U]: + """Apply a function to each leaf in a nested JSON structure.""" + if isinstance(value, dict): + return {k: json_map_leaves(func, v) for k, v in value.items()} + elif isinstance(value, list): + return [json_map_leaves(func, v) for v in value] + elif isinstance(value, tuple): + return tuple(json_map_leaves(func, v) for v in value) + else: + return func(value) + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: JSONTree[_T], + /, +) -> _T: + ... + + +@overload +def json_reduce_leaves( + func: Callable[[_U, _T], _U], + value: JSONTree[_T], + initial: _U, + /, +) -> _U: + ... + + +def json_reduce_leaves( + func: Callable[..., Union[_T, _U]], + value: JSONTree[_T], + initial: _U = ..., # type: ignore[assignment] + /, +) -> Union[_T, _U]: + """ + Apply a function of two arguments cumulatively to each leaf in a + nested JSON structure, from left to right, so as to reduce the + sequence to a single value. 
+ """ + if initial is ...: + return reduce(func, json_iter_leaves(value)) # type: ignore[arg-type] + + return reduce( + func, # type: ignore[arg-type] + json_iter_leaves(value), + initial, + ) diff --git a/vllm/logger.py b/vllm/logger.py new file mode 100644 index 0000000..2b0b9da --- /dev/null +++ b/vllm/logger.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Logging configuration for vLLM.""" +import datetime +import json +import logging +import os +import sys +from functools import lru_cache, partial +from logging import Logger +from logging.config import dictConfig +from os import path +from types import MethodType +from typing import Any, Optional, cast + +import vllm.envs as envs + +VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING +VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH +VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL +VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX + +_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " + "[%(filename)s:%(lineno)d] %(message)s") +_DATE_FORMAT = "%m-%d %H:%M:%S" + +DEFAULT_LOGGING_CONFIG = { + "formatters": { + "vllm": { + "class": "vllm.logging_utils.NewLineFormatter", + "datefmt": _DATE_FORMAT, + "format": _FORMAT, + }, + }, + "handlers": { + "vllm": { + "class": "logging.StreamHandler", + "formatter": "vllm", + "level": VLLM_LOGGING_LEVEL, + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagate": False, + }, + }, + "version": 1, + "disable_existing_loggers": False +} + + +@lru_cache +def _print_info_once(logger: Logger, msg: str) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.info(msg, stacklevel=2) + + +@lru_cache +def _print_warning_once(logger: Logger, msg: str) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.warning(msg, stacklevel=2) + + +class _VllmLogger(Logger): + """ + Note: + This class is just to provide type information. + We actually patch the methods directly on the :class:`logging.Logger` + instance to avoid conflicting with other libraries such as + `intel_extension_for_pytorch.utils._logger`. + """ + + def info_once(self, msg: str) -> None: + """ + As :meth:`info`, but subsequent calls with the same message + are silently dropped. + """ + _print_info_once(self, msg) + + def warning_once(self, msg: str) -> None: + """ + As :meth:`warning`, but subsequent calls with the same message + are silently dropped. + """ + _print_warning_once(self, msg) + + +def _configure_vllm_root_logger() -> None: + logging_config = dict[str, Any]() + + if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: + raise RuntimeError( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH " + "implies VLLM_CONFIGURE_LOGGING. Please enable " + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") + + if VLLM_CONFIGURE_LOGGING: + logging_config = DEFAULT_LOGGING_CONFIG + + if VLLM_LOGGING_CONFIG_PATH: + if not path.exists(VLLM_LOGGING_CONFIG_PATH): + raise RuntimeError( + "Could not load logging config. File does not exist: %s", + VLLM_LOGGING_CONFIG_PATH) + with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file: + custom_config = json.loads(file.read()) + + if not isinstance(custom_config, dict): + raise ValueError("Invalid logging config. 
Expected dict, got %s.", + type(custom_config).__name__) + logging_config = custom_config + + for formatter in logging_config.get("formatters", {}).values(): + # This provides backwards compatibility after #10134. + if formatter.get("class") == "vllm.logging.NewLineFormatter": + formatter["class"] = "vllm.logging_utils.NewLineFormatter" + + if logging_config: + dictConfig(logging_config) + + +def init_logger(name: str) -> _VllmLogger: + """The main purpose of this function is to ensure that loggers are + retrieved in such a way that we can be sure the root vllm logger has + already been configured.""" + + logger = logging.getLogger(name) + + methods_to_patch = { + "info_once": _print_info_once, + "warning_once": _print_warning_once, + } + + for method_name, method in methods_to_patch.items(): + setattr(logger, method_name, MethodType(method, logger)) + + return cast(_VllmLogger, logger) + + +# The root logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. +_configure_vllm_root_logger() + +logger = init_logger(__name__) + + +def _trace_calls(log_path, root_dir, frame, event, arg=None): + if event in ['call', 'return']: + # Extract the filename, line number, function name, and the code object + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_name = frame.f_code.co_name + if not filename.startswith(root_dir): + # only log the functions in the vllm root_dir + return + # Log every function call or return + try: + last_frame = frame.f_back + if last_frame is not None: + last_filename = last_frame.f_code.co_filename + last_lineno = last_frame.f_lineno + last_func_name = last_frame.f_code.co_name + else: + # initial frame + last_filename = "" + last_lineno = 0 + last_func_name = "" + with open(log_path, 'a') as f: + ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") + if event == 'call': + f.write(f"{ts} Call to" + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + else: + f.write(f"{ts} Return from" + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + except NameError: + # modules are deleted during shutdown + pass + return partial(_trace_calls, log_path, root_dir) + + +def enable_trace_function_call(log_file_path: str, + root_dir: Optional[str] = None): + """ + Enable tracing of every function call in code under `root_dir`. + This is useful for debugging hangs or crashes. + `log_file_path` is the path to the log file. + `root_dir` is the root directory of the code to trace. If None, it is the + vllm root directory. + + Note that this call is thread-level, any threads calling this function + will have the trace enabled. Other threads will not be affected. + """ + logger.warning( + "VLLM_TRACE_FUNCTION is enabled. It will record every" + " function executed by Python. This will slow down the code. 
It " + "is suggested to be used for debugging hang or crashes only.") + logger.info("Trace frame log is saved to %s", log_file_path) + if root_dir is None: + # by default, this is the vllm root directory + root_dir = os.path.dirname(os.path.dirname(__file__)) + sys.settrace(partial(_trace_calls, log_file_path, root_dir)) diff --git a/vllm/logging_utils/__init__.py b/vllm/logging_utils/__init__.py new file mode 100644 index 0000000..7ab4632 --- /dev/null +++ b/vllm/logging_utils/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.logging_utils.formatter import NewLineFormatter + +__all__ = [ + "NewLineFormatter", +] diff --git a/vllm/logging_utils/formatter.py b/vllm/logging_utils/formatter.py new file mode 100644 index 0000000..010b0a1 --- /dev/null +++ b/vllm/logging_utils/formatter.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, style="%"): + logging.Formatter.__init__(self, fmt, datefmt, style) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg diff --git a/vllm/logits_process.py b/vllm/logits_process.py new file mode 100644 index 0000000..e3faf20 --- /dev/null +++ b/vllm/logits_process.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Union + +import torch + +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor], + Callable[[list[int], list[int], torch.Tensor], + torch.Tensor]] +"""LogitsProcessor is a function that takes a list +of previously generated tokens, the logits tensor +for the next token and, optionally, prompt tokens as a +first argument, and returns a modified tensor of logits +to sample from.""" + + +def get_bad_words_logits_processors( + bad_words: list[str], + tokenizer: AnyTokenizer) -> list[LogitsProcessor]: + bad_words_ids: list[list[int]] = list() + + for bad_word in bad_words: + # To prohibit words both at the beginning + # and in the middle of text + # (related to add_prefix_space tokenizer parameter) + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + prompt = prefix + bad_word.lstrip() + + if isinstance(tokenizer, MistralTokenizer): + # Mistral tokenizers should not add special tokens + prompt_token_ids = tokenizer.encode(text=prompt) + else: + prompt_token_ids = tokenizer.encode(text=prompt, + add_special_tokens=False) + + # If no space at the beginning + # or if prefix space produces a new word token + if (not add_prefix_space) or ( + add_prefix_space + and prompt_token_ids[0] != bad_words_ids[-1][0] + and len(prompt_token_ids) == len(bad_words_ids[-1])): + bad_words_ids.append(prompt_token_ids) + + return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)] + + +class NoBadWordsLogitsProcessor: + _SMALLEST_LOGIT = float("-inf") + _NEUTRAL_LOGIT = 0.0 + + def __init__(self, bad_words_ids: list[list[int]]): + self.bad_words_ids = bad_words_ids + self.word_bias: torch.FloatTensor = None + + def __call__( + self, + past_tokens_ids: Union[list[int], tuple[int]], + logits: torch.FloatTensor, + ) -> torch.Tensor: + if self.word_bias is None: + self._init_word_bias(logits=logits) + + last_token_bias = torch.zeros_like(logits) + + for 
bad_word_ids in self.bad_words_ids: + if len(bad_word_ids) == 1: # 1-token words already processed + continue + + if len(bad_word_ids) > len(past_tokens_ids) + 1: + continue + + prefix_length = len(bad_word_ids) - 1 + last_token_id = bad_word_ids[-1] + actual_prefix = past_tokens_ids[-prefix_length:] + expected_prefix = bad_word_ids[:prefix_length] + + assert len(actual_prefix) == len(expected_prefix) + + is_match = tuple(actual_prefix) == tuple(expected_prefix) + last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match + else self._NEUTRAL_LOGIT) + + logits = logits + self.word_bias + last_token_bias + + return logits + + def _init_word_bias(self, logits: torch.FloatTensor) -> None: + # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501 + # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py + + vocab_size = logits.shape[-1] + + self._check_token_ids_bounds(vocab_size=vocab_size) + + self.word_bias = torch.zeros((vocab_size, ), + dtype=torch.float, + device=logits.device) + + for bad_word_ids in self.bad_words_ids: + if len(bad_word_ids) == 1: + bad_word_id = bad_word_ids[-1] + self.word_bias[bad_word_id] = self._SMALLEST_LOGIT + + def _check_token_ids_bounds(self, vocab_size: int) -> None: + invalid_token_ids = [] + + for bad_word_ids in self.bad_words_ids: + for token_id in bad_word_ids: + if token_id < 0 or token_id >= vocab_size: + invalid_token_ids.append(token_id) + + if len(invalid_token_ids) > 0: + raise ValueError( + f"The model vocabulary size is {vocab_size}," + f" but the following tokens" + f" were specified as bad: {invalid_token_ids}." + f" All token id values should be integers satisfying:" + f" 0 <= token_id < {vocab_size}.") diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py new file mode 100644 index 0000000..41e1ec9 --- /dev/null +++ b/vllm/lora/fully_sharded_layers.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA) + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec + + +def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. 
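+ + The base layer output is computed first; the input is then shrunk with the + rank-sharded LoRA A matrices, the resulting low-rank buffers are all-gathered + across tensor-parallel ranks, and finally expanded with the column-partitioned + LoRA B weights (plus optional bias) into that output.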
+ """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + buffers = tensor_model_parallel_all_gather(buffers) + layer.punica_wrapper.add_expand(output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + #NOTE: lora_a contains 2 subloras, and each sublora could be None. 
+ output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[0] is not None else None, + lora_a[1][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[1] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): + """ + Differs from QKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): + """ + Differs from MergedQKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
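+ # Shard sizes and start offsets are taken per q/k/v slice from the + # corresponding lora_a_stacked buffers before slicing each sublora.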
+ shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + + shard_size[0]] if lora_a[0] is not None else None, + lora_a[1][:, start_idx[1]:start_idx[1] + + shard_size[1]] if lora_a[1] is not None else None, + lora_a[2][:, start_idx[2]:start_idx[2] + + shard_size[2]] if lora_a[2] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. 
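+ # offset_start = tp_rank * shard_size, so each rank adds its LoRA output + # into its own column slice of the partial-sum result.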
+ shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000..7a9d523 --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,1258 @@ +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.distributed.utils import divide +# yapf: disable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + + +class BaseLayerWithLoRA(nn.Module): + + def slice_lora_a( + self, 
lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. 
+ base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, + 1, 0) + embeddings_indices = torch.narrow( + self.punica_wrapper._embeddings_indices, 1, 0, x.size(0)) + + indices = embeddings_indices[1] + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = embeddings_indices[0] + full_output = self.base_layer.forward(x + + (indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + self.punica_wrapper.add_lora_embedding(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + 
packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + @property + def weight(self): + return self.base_layer.weight + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + + self.output_slices: Tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLoRA and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. 
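+ # When tp_size > 1 the incoming weights are sliced for this rank first, + # then copied transposed into the preallocated stacked buffers at the + # given index.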
+ assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. + if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) + return output + + @property + def weight(self) -> torch.Tensor: + + # unquantizedLinear + if hasattr(self.base_layer, "weight"): + return self.base_layer.weight + # Compressed Tensor + elif hasattr(self.base_layer, "weight_packed"): + return self.base_layer.weight_packed + # GPTQ/AWQ + elif hasattr(self.base_layer, "qweight"): + return self.base_layer.qweight + # marlin + elif hasattr(self.base_layer, "B"): + return self.base_layer.B + # HQQ marlin + elif hasattr(self.base_layer, "W_q"): + return self.base_layer.W_q + else: + raise ValueError(f"Unsupported base layer: {self.base_layer}") + + @property + def bias(self) -> Optional[torch.Tensor]: + if hasattr(self.base_layer, "bias"): + return self.base_layer.bias + else: + return None + + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. 
+ There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. + if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + + if not self.base_layer.return_bias: + return output + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. 
+ """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: + super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. + """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] + return lora_b + + def slice_bias( + self, bias: List[Union[torch.Tensor, + None]]) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod 
+ @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. 
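+        # One slice per projection (q, k, v). The q slice may be larger than
+        # the k/v slices when the model has fewer KV heads than query heads.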
+ self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. 
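+        # apply() computes the base layer's partial matmul and adds the LoRA
+        # delta to it; the all-reduce below then sums the rank-local partial
+        # outputs across tensor-parallel ranks.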
+ output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[List[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( 
+ self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + + # HPU needs special handling to prune out dummy samples. + if current_platform.is_hpu(): + lora_logits = lora_logits[:logits.shape[0], :] + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + # LogitsProcessorWithLoRA always using bgmv + self.punica_wrapper.add_lora_logits(logits, hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. 
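+        # It is never replaced via the generic module matching; instead,
+        # LoRAModelManager._create_lora_modules wraps it explicitly when the
+        # lm_head module is replaced (see from_layer_logits_processor).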
+ return False + + +class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. + + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + ... + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.punica_wrapper.long_lora_indices, + ) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000..00299bf --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional +from typing import Sequence as GenericSequence + +import torch +import torch.types + +from vllm.lora.peft_helper import PEFTHelper +from vllm.utils import is_pin_memory_available + + +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.bias = bias + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling = scaling + + def optimize(self) -> "LoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + 
if self.scaling == 1: + return self + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + + @classmethod + def from_config( + cls, + module_name: str, + peft_helper: PEFTHelper, + embeddings_tensor: Optional[torch.Tensor] = None, + ) -> "LoRALayerWeights": + return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None, + None, None, embeddings_tensor, + peft_helper.vllm_lora_scaling_factor) + + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.types.Device, + embeddings_tensor_dim: Optional[int] = None, + bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and is_pin_memory_available() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + if bias_enabled: + bias = torch.zeros([output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + else: + bias = None + + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + bias=bias, + embeddings_tensor=embeddings_tensor, + ) + + +class PackedLoRALayerWeights(LoRALayerWeights): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[Optional[int]], + lora_a: List[Optional[torch.Tensor]], + lora_b: List[Optional[torch.Tensor]], + bias: Optional[List[Optional[torch.Tensor]]] = None, + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + bias=bias, + scaling=scaling, # type: ignore + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ # type: ignore + lora_alpha / self.rank # type: ignore # noqa + for lora_alpha in self.lora_alphas + ] + + @classmethod + def pack( + cls, loras: GenericSequence[Optional["LoRALayerWeights"]] + ) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. 
+ """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + [lora.bias if lora is not None else None for lora in loras], + scaling=[ + 1 if lora is not None else None # type: ignore + for lora in loras + ]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore + continue + self.lora_b[i] *= self.scaling[i] # type: ignore + self.scaling[i] = 1 # type: ignore + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000..8164d91 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,802 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +import math +import os +import re +from dataclasses import dataclass, field +from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type, + Union) + +import safetensors.torch +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.layers import (BaseLayerWithLoRA, + LinearScalingRotaryEmbeddingWithLoRA, + LoRAMapping) +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.lora.utils import (from_layer, from_layer_logits_processor, + get_supported_lora_modules, + is_regex_target_modules, + parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models import SupportsLoRA, supports_multimodal +from vllm.model_executor.models.interfaces import is_pooling_model +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.utils import is_pin_memory_available + +logger = init_logger(__name__) + +_GLOBAL_LORA_ID = 0 + + +@dataclass +class LongContextLoRAContext: + """Context for lora adapters that support long context.""" + # The scaling factors to support long context lora fine tuned models. + scaling_factors: List[float] + # dimension to apply rotary embedding. + rot_dim: int + # offsets to the sin_cos_cache for each lora_id loaded. + # This value is dynamically modified. 
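+    # Populated by LoRAModelManager._set_long_lora_context when a
+    # long-context adapter is added.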
+ offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +class LoRAModel(AdapterModel): + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRALayerWeights], + scaling_factor: Optional[float] = None, + ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + scaling_factor: Scaling factor to support long context lora model. + None if the lora is not tuned for long context support. + """ + self.id = lora_model_id + # Scaling factor for long context lora model. None if it is not + # fine tuned for the long context. + self.scaling_factor = scaling_factor + assert ( + lora_model_id + > 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRALayerWeights] = loras + + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + def check_lora_name(self, lora_name: str) -> bool: + return lora_name in self.loras + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + tensors: Dict[str, torch.Tensor], + peft_helper: PEFTHelper, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + tensor_name, weights_mapper) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + assert embedding_modules is not None + embeddings_module = next( + (k for k in embedding_modules if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + embedding_modules[embeddings_module]].to( + device=device, dtype=dtype) + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights.from_config( + module_name, peft_helper, lora_embeddings_tensor) + + if is_bias: + loras[module_name].bias = tensor.to(device=device, + dtype=dtype).t() + bias = tensor.to(device=device, dtype=dtype).t() + if pin_memory: + bias = bias.pin_memory() + loras[module_name].bias = bias + elif is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + assert embedding_padding_modules is not None 
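+                # For embedding/lm_head modules, pad lora_b's output dim up
+                # to target_embedding_padding so it matches the padded
+                # vocabulary dimension of the base weight.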
+ if any(name in module_name + for name in embedding_padding_modules + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + + return cls(lora_model_id, + peft_helper.r, + loras, + scaling_factor=peft_helper.vllm_long_context_scaling_factor) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: List[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + peft_helper: Loaded lora configuration information. + lora_model_id: LoRA model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + + unexpected_modules: List[Union[list[str], str]] + if os.path.isfile(lora_tensor_path): + tensors: Dict[str, torch.Tensor] = {} + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + with safetensors.safe_open(lora_tensor_path, + framework="pt") as f: # type: ignore + for lora_module in f.keys(): # noqa + module_name, _, _ = parse_fine_tuned_lora_name( + lora_module, weights_mapper) + part_name = module_name.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module_name) + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + # Load tensors if there are only expected modules. + for module in f.keys(): # noqa + tensors[module] = f.get_tensor(module) + elif os.path.isfile(lora_bin_file_path): + # When a bin file is provided, we rely on config to find unexpected + # modules. 
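+            # target_modules from the adapter config may be a single string
+            # rather than a list; normalize it before the check below.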
+ unexpected_modules = [] + target_modules = peft_helper.target_modules + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + peft_helper.target_modules, expected_lora_modules): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") + tensors = torch.load(lora_bin_file_path, + map_location=device, + weights_only=True) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path, + map_location=device, + weights_only=True) + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, + weights_mapper=weights_mapper) + + +class LoRAModelManager(AdapterModelManager): + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: SupportsLoRA, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + """ + self.lora_config = lora_config + self.device = device + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.lora_slots + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.vocab_size = vocab_size + self.long_lora_context: Optional[LongContextLoRAContext] = None + self.punica_wrapper = get_punica_wrapper( + max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device, + max_loras=self.lora_config.max_loras) + # Scaling factor -> offset to the sin_cos_cache to it. + # Used for long context lora. + self.scaling_factor_to_offset: Dict[float, int] = {} + super().__init__(model) + + self.supported_lora_modules = get_supported_lora_modules(self.model) + assert self.supported_lora_modules, "No supported LoRA modules found in" + f"{self.model.__class__.__name__}." + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. 
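+            # The matching rotary_emb modules are then wrapped by
+            # LinearScalingRotaryEmbeddingWithLoRA in _create_lora_modules.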
+ self.supported_lora_modules.append("rotary_emb") + self.packed_modules_mapping = copy.deepcopy( + self.model.packed_modules_mapping) + # Used to indicate whether the model is a multimodal model + self.supports_mm: bool = ( + supports_multimodal(self.model) + # In case the model only supports LoRA for + # text modules (e.g. ChatGLM) + and hasattr(self.model, "get_mm_mapping")) + self.is_pooling_model = is_pooling_model(self.model) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, BaseLayerWithLoRA] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._last_mapping: Optional[LoRAMapping] = None + self._create_lora_modules() + self.model.lora_manager = self + self.adapter_type = 'LoRa' + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def lora_slots(self) -> int: + return self.lora_config.max_loras + + @property + def adapter_slots(self) -> int: + return self.lora_slots + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" + if lora_id in self._active_adapters: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] + logger.debug("Activating LoRA. int id: %d, slot index: %d", + lora_model.id, index) + self.lora_index_to_id[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = self._get_lora_layer_weights(lora_model, module_name) + if module_lora: + module_lora.optimize() + # Bias is not explicitly enabled with the flag enable_lora_bias. + bias = module_lora.bias + if ((torch.is_tensor(bias) or + (isinstance(bias, Sequence) and any(b is not None + for b in bias))) + and not self.lora_config.bias_enabled): + module_lora.bias = None + raise ValueError( + f"Adapter bias cannot be used for {module_name}" + " without --enable-lora-bias.") + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor, + module_lora.bias) + else: + module.reset_lora(index) + return True + + def _deactivate_adapter(self, lora_id: int): + try: + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None + except ValueError: + pass + + def _set_long_lora_context(self, lora: LoRAModel): + if self.long_lora_context is None: + return + + if lora.scaling_factor is None: + return + + if (lora.scaling_factor not in self.scaling_factor_to_offset): + raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" + " has not been initialized.") + + offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) + if offsets: + self.long_lora_context.offsets_by_lora_id[lora.id] = offsets + + def _add_adapter(self, lora: LoRAModel): + self._create_merged_loras_inplace(lora) + self._registered_adapters[lora.id] = lora + self._set_long_lora_context(lora) + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in LoRAModelManager. 
" + "Use LRUCacheLoRAModelManager for pinning") # type: ignore + + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: + # update lora states + self.punica_wrapper.update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) + + def remove_all_adapters(self): + """Remove all LoRAModels from the manager.""" + self._registered_adapters.clear() + self.lora_index_to_id = [None] * self.lora_slots + self._active_adapters.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if isinstance(module, PPMissingLayer): + continue + if not self._match_target_modules(module_name): + continue + # A temporary approach for multimodal models to support LoRA + # TODO: Remove this restriction + if self._filter_unsupported_mm_module(module_name): + logger.warning( + "Regarding multimodal models, vLLM currently only supports " + "adding LoRA to language model, %s will be ignored.", + module_name, + ) + continue + parts = module_name.split(".")[-1] + packed_moduled_lst = self.packed_modules_mapping.get(parts, []) + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.lora_slots, self.lora_config, + packed_moduled_lst, self.model.config)) + + # LinearScalingRotaryEmbeddingWithLoRA is used to handle + # long context lora. Register relevant metadata. + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA): + self.long_lora_context = LongContextLoRAContext( + new_module.scaling_factors, new_module.rotary_dim) + self.scaling_factor_to_offset = \ + new_module.scaling_factor_to_offset + # (yard1): TODO make this more robust + if "lm_head" in module_name: + logits_processor_module = self.model.get_submodule( + "logits_processor") + new_module = replace_submodule( + self.model, "logits_processor", + from_layer_logits_processor(logits_processor_module, + module, self.lora_slots, + self.lora_config, + self.model.config)) + + # In some models, especially multimodal ones, layers with the same + # name may have different types, such as nn.Linear and + # ReplicatedLinear. The nn.Linear layers cannot be replaced with + # LoRA layers, leading to assertion error. The following check + # aims to prevent this error + if self.supports_mm and not isinstance(new_module, + BaseLayerWithLoRA): + continue + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + # All lora layers share the same punica_wrapper based on reference. 
+ new_module.set_mapping(self.punica_wrapper) + + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + def create_dummy_lora( + self, + lora_id: int, + rank: int, + scaling_factor: Optional[float], + embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}, scaling_factor) + for module_name, module in self.model.named_modules(): + bias_enabled = self.lora_config.bias_enabled + if (not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA) + or self._filter_unsupported_mm_module(module_name)): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + assert embedding_modules is not None + if parts[-1] in embedding_modules: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + input_dim, + output_dim, + rank, + module.lora_a_stacked[0].dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim, + bias_enabled=bias_enabled) + else: + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + module.lora_a_stacked[0].shape[-1], + module.lora_b_stacked[0].shape[-2], + rank, + module.lora_a_stacked[0].dtype, + "cpu", + bias_enabled=bias_enabled, + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras: List[Optional[LoRALayerWeights]] = [] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." + r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + bias_enabled=bias_enabled, + ) + lora.optimize() + subloras.append(lora) + lora = PackedLoRALayerWeights.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.supported_lora_modules) + + def _filter_unsupported_mm_module(self, module_name: str) -> bool: + """ + Regarding multimodal models, vLLM currently only supports adding LoRA to + language model. LoRA for other modules, such as the vision tower, will + be filtered out. + """ + if self.supports_mm: + module_mapping: MultiModelKeys = self.model.get_mm_mapping() + prefix_lst = module_mapping.connector + module_mapping.tower_model + return any( + [module_name.startswith(prefix) for prefix in prefix_lst]) + return False + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name, []) + # When replacements is less than or equal to 1, it indicates that this + # module is not a packed module. 
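+        # e.g. "model.layers.0.self_attn.qkv_proj" typically maps to the
+        # sub-modules q_proj, k_proj and v_proj via packed_modules_mapping.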
+ if len(replacements) <= 1: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras: List[Optional[LoRALayerWeights]] = [] + replaced_module: Set[str] = set() + has_replacement = False + for r in new_module_names: + lora = self._get_lora_layer_weights(lora_model, r) + replacement_loras.append(lora) + if lora: + has_replacement = True + replaced_module.add(r) + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + # HACK Temporary solution for the pool model. + if self.is_pooling_model and not lora_model.check_lora_name( + module_name): + replaced_module_name = module_name.replace("model.", "") + if lora_model.check_lora_name(module_name): + module_name = replaced_module_name + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) + # Remove the modules that have been replaced. + for module in replaced_module: + lora_model.loras.pop(module, None) + + def _get_lora_layer_weights( + self, lora_model: LoRAModel, + module_name: str) -> Optional[LoRALayerWeights]: + org_module_name = module_name + if self.is_pooling_model and not lora_model.check_lora_name( + module_name): + # If it's a pool model, and the layer name is not found, + # remove the prefix 'model.' and search again. + module_name = module_name.replace("model.", "") + if lora_model.check_lora_name(module_name): + org_module_name = module_name + logger.info_once( + "For the pool model, successfully loaded the LoRA weights " + "after removing the prefix 'model.'.") + return lora_model.get_lora(org_module_name) + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. 
Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class LoRALRUCache(AdapterLRUCache[LoRAModel]): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + bool]): + super().__init__(capacity, deactivate_lora_fn) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__(self, model: nn.Module, max_num_seqs: int, + max_num_batched_tokens: int, vocab_size: int, + lora_config: LoRAConfig, device: torch.device): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, device) + self._registered_adapters: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) + if lora.id not in self._registered_adapters: + self._add_adapter(lora) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(lora.id) + was_added = False + return was_added + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(lora_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(lora_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + self._pin_lora_in_cpu_cache(lora_id) + self._pin_lora_in_gpu_cache(lora_id) + return True + + def _pin_lora_in_cpu_cache(self, lora_id: int): + try: + self._registered_adapters.pin(lora_id) + except ValueError as err: + raise ValueError("Pinning failed. 
" + f"LoRA {lora_id} is not registered.") from err + + def _pin_lora_in_gpu_cache(self, lora_id: int): + if lora_id not in self._active_adapters: + # move lora to gpu if not already active + self.activate_adapter(lora_id) + + self._active_adapters.pin(lora_id) + + +def create_lora_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not hasattr(model, "packed_modules_mapping"): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + device=device, + **kwargs) + return lora_manager diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/lora/ops/torch_ops/__init__.py b/vllm/lora/ops/torch_ops/__init__.py new file mode 100644 index 0000000..85601d5 --- /dev/null +++ b/vllm/lora/ops/torch_ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py new file mode 100644 index 0000000..af79f98 --- /dev/null +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: 
torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py new file mode 100644 index 0000000..acae0d9 --- /dev/null +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.triton_ops.lora_expand import lora_expand +from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta +from vllm.lora.ops.triton_ops.lora_shrink import lora_shrink + +__all__ = [ + "lora_expand", + "lora_shrink", + "LoRAKernelMeta", +] diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py new file mode 100644 index 0000000..5b8c193 --- /dev/null +++ b/vllm/lora/ops/triton_ops/kernel_utils.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Utilities for Punica kernel construction. +""" +import triton +import triton.language as tl + + +@triton.jit +def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr, + b_dtype: tl.constexpr): + """ + Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of + B (k x n), iterate, through the K dimension to compute the partial/complete + matrix block product. + If SPLIT_K == 1, the output m x n product is complete. + If SPLIT_K > 1, the thread block computes partial outputs. The partial + outputs are then atomically summed in the caller code. 
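+
+    Illustrative example (annotation, not part of the upstream docstring):
+    with K = 64, BLOCK_K = 16 and SPLIT_K = 2, each program runs
+    tl.cdiv(64, 32) = 2 iterations and advances its A/B pointers by
+    BLOCK_K * SPLIT_K = 32 elements per step, so the two split-K programs
+    interleave over the K blocks and their partial sums are combined by the
+    caller. Since 64 % 32 == 0, EVEN_K holds and no load masking is needed.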
+ Args: + a_ptr: Array of pointers, identifying rows of A + b_ptr: Array of pointers, identifying columns of B + ak_stride: K dimension stride of the A matrix + bk_stride: K dimension stride of the B matrix + K: Length of the K dimension + BLOCK_M: M dimension of the output block m x n + BLOCK_N: N dimension of the output block m x n + BLOCK_K: K dimension atom + EVEN_K: True if the blocks of A and B can be loaded without any + masking. + SPLIT_K: Parameter signifying parallelism in the K dimension. + CAST_TYPE: if True, cast the values from the A matrix to the B + matrix dtype. + b_dtype: datatype of the B matrix + """ + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] + < K - k * (BLOCK_K * SPLIT_K), + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] + < K - k * (BLOCK_K * SPLIT_K), + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(b_dtype) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * SPLIT_K * ak_stride + b_ptr += BLOCK_K * SPLIT_K * bk_stride + return accumulator + + +@triton.jit +def do_expand_kernel( + pid_n, + lora_index, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SAME_STRIDE: tl.constexpr, + SLICE_NUM: tl.constexpr, + EVEN_K: tl.constexpr, + CAST_TYPE: tl.constexpr, + ADD_INPUTS: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, + compute the matrix product and store in the appropriate output location. + Given that this is an expand kernel, we don't perform any split-K reduction + as the K dimension is assumed to be small. + """ + + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: + # integer + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + # pointer + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + # Identify the input_ptr and lora_ptr from slice_id. + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + cur_lora_ptr = lora_ptr + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = tl.arange(0, BLOCK_K) + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) + + # Compute the block matrix product. 
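+    # NOTE (illustrative annotation, not in the original change): for the
+    # expand direction K is the LoRA rank (the host code sets
+    # K = lora_b_weights[0].shape[-1]), which is small, so no split-K
+    # reduction is performed and SPLIT_K is pinned to 1 below; each program
+    # produces a complete m x n output tile on its own.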
+ SPLIT_K = 1 + accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, + offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, + CAST_TYPE, cur_lora_ptr.dtype.element_ty) + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + offset_cm = tl.arange(0, BLOCK_M) + c_ptr = (out_ptr + ram[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] + < (cur_slice_start + N)) + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@triton.jit +def do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_index, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, compute the + matrix product and store in the appropriate output location. + """ + + # Identify the lora_ptr from slice_id. + if SLICE_NUM == 1: + # current lora ptr + cur_lora_ptr = lora_ptr + else: + # current lora ptr + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + a_ptr = (input_ptr + ram[:, None] * input_d0_stride + + offset_k[None, :] * input_d1_stride) + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride) + + # Compute partial/complete block matrix product. + accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k, + K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False, + cur_lora_ptr.dtype.element_ty) + + # Identify the C output pointers to store the results of the accumulator. 
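+    # NOTE (illustrative annotation, not in the original change): the shrink
+    # output is laid out as [num_slices, num_tokens, lora_rank], so the
+    # output pointer below is offset by slice_id * output_d0_stride. The tile
+    # is scaled and then either stored directly (SPLIT_K == 1) or atomically
+    # accumulated across the SPLIT_K programs that share this output tile.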
+ offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_cm = tl.arange(0, BLOCK_M) + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + + slice_id * output_d0_stride) + c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[ + None, :] * output_d2_stride + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) + + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) diff --git a/vllm/lora/ops/triton_ops/lora_expand.py b/vllm/lora/ops/triton_ops/lora_expand.py new file mode 100644 index 0000000..eacc6fb --- /dev/null +++ b/vllm/lora/ops/triton_ops/lora_expand.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import List + +import torch +import triton +import triton.language as tl + +from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel +from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr): + + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_mn = tl.program_id(axis=0) + pid_m = pid_mn % cta_m_num + pid_n = (pid_mn // cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # When the output dimensions of each slice are the same,cur_n=N, otherwise + # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's + # qkv linear. + curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) + if pid_n * BLOCK_N >= curr_N: + # Early exit CTA. + return + + # num rows this CTA should process. + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + + # Load all relevant row indices. 
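+    # NOTE (illustrative annotation, not in the original change): offset_m is
+    # taken modulo cta_m_len so that, when fewer than BLOCK_M rows remain for
+    # this LoRA, the row gather simply wraps around; the duplicated rows are
+    # never written back because the store mask limits writes to cta_m_len
+    # rows.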
+ offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_expand_kernel( + pid_n, + lora_id, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + curr_N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M, + BLOCK_N, + BLOCK_K, + SAME_STRIDE, + SLICE_NUM, + EVEN_K, + CAST_TYPE, + ADD_INPUTS) + + +@torch.inference_mode() +def _lora_expand( + inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + lora_b_weights: List[ + torch.Tensor], # shape [num_lora, hidden_size, lora_rank] + output_tensor: torch. + Tensor, # shape [num_tokens, hidden_size * num_slices] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + no_lora_flag_cpu: torch.Tensor, # shape [1] + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (List[torch.Tensor]): lora'b weight + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. + num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates + if there are any requests that require LoRA. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. + """ + + assert no_lora_flag_cpu.numel() == 1 + if no_lora_flag_cpu.item(): + # None of the inputs require LoRA. + return + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(0) == len(lora_b_weights) + assert output_tensor.is_contiguous() + + # metadata sanity check. 
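+    # Illustrative shapes (annotation, not in the original change): for
+    # num_slices = 2, 8 tokens and rank 16, `inputs` is [2, 8, 16],
+    # `token_lora_mapping` has 8 entries (== M below), and
+    # `lora_token_start_loc` has one more entry than `lora_ids`.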
+ M = inputs.size(1) + assert token_lora_mapping.size(0) == M + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( + 0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, + inputs.device) + + K = lora_b_weights[0].shape[-1] # K= rank + ADD_INPUTS = add_inputs + MAX_LORAS = lora_ids.size(0) + CAST_TYPE = False + NUM_SLICES = len(lora_b_weights) + + # Triton kernel configs. + BLOCK_M = 64 + BLOCK_N = 128 + BLOCK_K = 16 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + MAX_NREG = None + + EVEN_K = K % BLOCK_K == 0 # type: ignore + + if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only a few input tokens require + # LoRA. This might not be the best in all cases. + grid = ( + triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks simply exit. + MAX_LORAS, + ) + + _lora_expand_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + MAX_N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_tensor, + inputs.stride(0), + inputs.stride(1), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + hidden_sizes_tensor, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + NUM_SLICES, + same_stride, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + maxnreg=MAX_NREG, + ) + + return + + +def _lora_expand_fake( + inputs: torch.Tensor, + lora_b_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_expand", + op_func=_lora_expand, + mutates_args=["output_tensor"], + fake_impl=_lora_expand_fake, + ) + lora_expand = torch.ops.vllm.lora_expand + +except AttributeError: + lora_expand = _lora_expand diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py new file mode 100644 index 0000000..1dcdfc8 --- /dev/null +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +LoRA kernels metadata preparation utilities. +""" + +from dataclasses import dataclass +from typing import Tuple, Union + +import torch + + +@dataclass +class LoRAKernelMeta: + token_lora_mapping: torch.Tensor + token_indices_sorted_by_lora_ids: torch.Tensor + active_lora_ids: torch.Tensor + num_tokens_per_lora: torch.Tensor + lora_token_start_loc: torch.Tensor + + # The V1 architecture uses the traced torch.compile graphs to execute + # a forward pass. Things to note about this process, + # 1. 
The tracing infers all python scalar datatype objects into a constant + # value. + # 2. The tracing cannot handle dynamic control flow. (dynamic control flow + # is an experimental feature in pytorch) + # 3. The internals of torch.ops functions are not traced. + # We disguise the "no_lora" flag as a cpu tensor and leverage point number 3 + # to early exit from inside the lora_expand / lora_shrink torch operation. + no_lora_flag_cpu: torch.Tensor + + @staticmethod + def make(max_loras: int, max_num_tokens: int, + device: Union[torch.device, str]) -> "LoRAKernelMeta": + + token_lora_mapping = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + # +1 because "no-lora" is also a possibility + # example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1] + # is a possibility. + active_lora_ids = torch.empty(max_loras + 1, + dtype=torch.int32, + device=device) + + # using running example, [3, 10, 5, 2] is a possibility. + num_tokens_per_lora = torch.zeros(max_loras + 1, + dtype=torch.int32, + device=device) + + # +2 for this because, the first index is always 0. + # using running example, lora_token_start_loc + # is [0, 3, 13, 18, 20]. + lora_token_start_loc = torch.zeros(max_loras + 2, + dtype=torch.int32, + device=device) + + no_lora_flag_cpu = torch.tensor([False], + dtype=torch.bool, + device='cpu') + + return LoRAKernelMeta( + token_lora_mapping=token_lora_mapping, + token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids=active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc=lora_token_start_loc, + no_lora_flag_cpu=no_lora_flag_cpu) + + def _reset(self): + self.active_lora_ids.fill_(-1) + self.num_tokens_per_lora.fill_(0) + self.lora_token_start_loc.fill_(0) + self.no_lora_flag_cpu.fill_(False) + + def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: + """ + Prepare kernel metadata tensors for the current forward pass. + + Args: + token_lora_tensor (torch.Tensor): Tensor containing lora indices + for each input token. + """ + + self._reset() + + # Check and record no-lora case. + no_lora = torch.all(token_lora_mapping == -1) + self.no_lora_flag_cpu[0] = no_lora + + if no_lora: + # Early exit. LoRA kernels will not be run. 
+ return + + num_tokens = token_lora_mapping.size(0) + + # copy token lora mapping + self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, + non_blocking=True) + + # token_indices_sorted_by_lora_ids + _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, + stable=True) + # start gpu transfer + self.token_indices_sorted_by_lora_ids[:num_tokens].copy_( + token_indices_sorted_by_lora_ids, non_blocking=True) + + # active_lora_ids, num_tokens_per_lora + lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, + sorted=False, + return_counts=True) + self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, + non_blocking=True) + self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_( + num_tokens_per_lora, non_blocking=True) + + # lora_token_start_loc + lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) + self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( + lora_token_start_loc, non_blocking=True) + + def meta_args( + self, token_nums: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: + """ + This function returns the kernel metadata required for the current + forward pass execution of the kernel. The function returns all the + metadata required by the kernel, in order, as a tuple, so it can be + unpacked directly during the lora_shrink/lora_expand function call. + + Args: + token_nums (int): Number of input tokens in the current forward + pass. + """ + return ( + self.token_lora_mapping[:token_nums], + self.token_indices_sorted_by_lora_ids[:token_nums], + self.num_tokens_per_lora, + self.lora_token_start_loc, + self.active_lora_ids, + self.no_lora_flag_cpu, + ) diff --git a/vllm/lora/ops/triton_ops/lora_shrink.py b/vllm/lora/ops/triton_ops/lora_shrink.py new file mode 100644 index 0000000..8233193 --- /dev/null +++ b/vllm/lora/ops/triton_ops/lora_shrink.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import List + +import torch +import triton +import triton.language as tl + +from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel +from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, + token_indices_sorted_by_lora_ids, num_tokens_per_lora, + lora_token_start_loc, lora_ids, scaling, + input_d0_stride, input_d1_stride, lora_d0_stride, + lora_d1_stride, lora_d2_stride, output_d0_stride, + output_d1_stride, output_d2_stride, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr): + + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_sk_m_n = tl.program_id(axis=0) + pid_sk = pid_sk_m_n % SPLIT_K + pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num + pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # num rows this CTA should process. 
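+    # Illustrative grid decomposition (annotation, not in the original
+    # change): axis 0 packs (SPLIT_K, M-blocks, N-blocks). For example, with
+    # SPLIT_K = 8, cta_m_num = 4 and cta_n_num = 2, program id 13 maps to
+    # pid_sk = 13 % 8 = 5, pid_m = (13 // 8) % 4 = 1 and pid_n = 0.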
+ cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + + # Load all relevant row indices. + offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + SLICE_NUM) + + +@torch.inference_mode() +def _lora_shrink( + inputs: torch.Tensor, # shape [num_tokens, hidden_size] + lora_a_weights: List[ + torch.Tensor], # shape [num_loras, lora_rank, hidden_size] + output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + no_lora_flag_cpu: torch.Tensor, # shape [1] + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): Input tensor + lora_a_weights (List[torch.Tensor]): LoRA weights + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. + num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates + if there are any requests that require LoRA. + scaling (float): Scaling factor. + """ + + assert no_lora_flag_cpu.numel() == 1 + if no_lora_flag_cpu.item(): + # None of the inputs require LoRA. 
+ return + + assert inputs.dtype == lora_a_weights[0].dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + for weight in lora_a_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(1) == lora_a_weights[0].size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + # metadata sanity check + M = inputs.size(0) + assert token_lora_mapping.size(0) == M + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( + 0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, + lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) + N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank + NUM_SLICES = len(lora_a_weights) + MAX_LORAS = lora_ids.size(0) + + # Triton kernel configs + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 256 if M < 128 else 32 + SPLIT_K = 64 if M < 128 else 8 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + MAX_NREG = None + + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only few of the input tokens + # require LoRA. This might not be the best in all cases. + grid = ( + SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks exit early. + MAX_LORAS, + ) + + _lora_shrink_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor.stride(2), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + NUM_SLICES, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + maxnreg=MAX_NREG, + ) + + return + + +def _lora_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + scaling: float, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_shrink", + op_func=_lora_shrink, + mutates_args=["output_tensor"], + fake_impl=_lora_shrink_fake, + ) + lora_shrink = torch.ops.vllm.lora_shrink + +except AttributeError: + lora_shrink = _lora_shrink diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py new file mode 100644 index 0000000..f779bbc --- /dev/null +++ b/vllm/lora/ops/triton_ops/utils.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Tuple + +import torch + +_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} +_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} + + +def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): + """ + `_LORA_A_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. 
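+    (Annotation, not upstream text:) the lookup key is the tuple of
+    data_ptr() values of the per-slice weights, so the pointer and stride
+    collection below runs once per distinct set of lora_a weights; later
+    calls return the cached tuple.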
+ Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) + + if values := _LORA_A_PTR_DICT.get(key): + return values + + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + tensor_ptrs = [] + for lora_a_weight in lora_a_weights: + if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_a_weight.size(1) == 1 + lora_a_weight = lora_a_weight.squeeze(dim=1) + else: + assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_a_weight.is_contiguous() + tensor_ptrs.append(lora_a_weight.data_ptr()) + lora_strides_d0.append(lora_a_weight.stride(0)) + lora_strides_d1.append(lora_a_weight.stride(1)) + lora_strides_d2.append(lora_a_weight.stride(2)) + if len(lora_a_weights) > 1: + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + else: + lora_ptr_tensor = lora_a_weights[0] + + if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1 + or len(set(lora_strides_d2)) > 1): + raise ValueError("All LoRA weights must have the same stride.") + + _LORA_A_PTR_DICT[key] = ( + lora_ptr_tensor, + lora_strides_d0[0], + lora_strides_d1[0], + lora_strides_d2[0], + ) + return _LORA_A_PTR_DICT.get(key) + + +def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, + device: torch.device): + """ + `_LORA_B_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + + """ + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + if values := _LORA_B_PTR_DICT.get(key): + return values + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + hidden_sizes = [] + slice_offset = offset_start + for lora_b_weight in lora_weights: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) + hidden_sizes.append(lora_b_weight.size(1)) + + if len(lora_weights) > 1: + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + else: + slice_start_tensor = slice_offset_lst[0] + lora_ptr_tensor = lora_b_weight[0] + + # If each lora has the same stride, there's no need to use a + # tensor for storage. 
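+    # Illustrative example (annotation, not in the original change): for a
+    # fused QKV projection the lora_b slices can have different hidden sizes,
+    # so their strides differ; in that case per-slice stride and hidden-size
+    # tensors are built and same_stride is False, and the expand kernel loads
+    # them per slice_id instead of using scalar constants.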
+ if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and + len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1: + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] + hidden_sizes_tensor = hidden_sizes[0] + same_stride = True + + else: + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device) + same_stride = False + # MAX_N is the maximum hidden size among all the lora_b weights + MAX_N = max(hidden_sizes) + _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, + lora_strides_d0_tensor, lora_strides_d1_tensor, + lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) + return _LORA_B_PTR_DICT.get(key) diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py new file mode 100644 index 0000000..f694436 --- /dev/null +++ b/vllm/lora/peft_helper.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py + +import json +import math +import os +from dataclasses import MISSING, dataclass, field, fields +from typing import List, Literal, Optional, Union + +from vllm.config import LoRAConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class PEFTHelper: + """ + A helper class for PEFT configurations, specifically designed for LoRA. + This class handles configuration validation, compatibility checks for + various LoRA implementations. + """ + + # Required fields + r: int + lora_alpha: int + target_modules: Union[list[str], str] + + bias: Literal["none", "all", "lora_only"] = field(default="none") + modules_to_save: Optional[list[str]] = field(default=None) + # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732) + use_rslora: bool = field(default=False) + # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353) + use_dora: bool = field(default=False) + # long context lora field + context_length: int = field(default=0) + # Extra vllm field, start with 'vllm_' to avoid conflict + vllm_lora_scaling_factor: float = field(default=1.0) + vllm_max_position_embeddings: Optional[int] = field(default=False) + vllm_long_context_scaling_factor: Optional[float] = field(default=None) + + def _validate_features(self) -> List[str]: + """ + Check if there are any unsupported LoRA features. 
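+
+        Returns:
+            A list of human-readable error messages; an empty list means all
+            requested features are supported.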
+ """ + error_msg = [] + if self.modules_to_save: + error_msg.append("vLLM only supports modules_to_save being None.") + if self.use_dora: + error_msg.append("vLLM does not yet support DoRA.") + return error_msg + + def __post_init__(self): + if self.use_rslora: + logger.info_once("Loading LoRA weights trained with rsLoRA.") + self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r) + else: + self.vllm_lora_scaling_factor = self.lora_alpha / self.r + if self.context_length: + if self.vllm_max_position_embeddings is None: + self.vllm_max_position_embeddings = self.context_length + self.vllm_long_context_scaling_factor = float( + math.ceil(self.context_length / + self.vllm_max_position_embeddings)) + + @classmethod + def from_dict(cls, config_dict: dict) -> "PEFTHelper": + # Get all field information from the class + class_fields = {f.name: f for f in fields(cls)} + # Check for required fields + required_fields = { + name + for name, f in class_fields.items() + if f.default is MISSING and f.default_factory is MISSING + } + + # Identify any missing required fields + missing_fields = required_fields - set(config_dict.keys()) + if missing_fields: + raise ValueError( + f"Missing required configuration fields: {missing_fields}") + + # Filter out fields that aren't defined in the class + filtered_dict = { + k: v + for k, v in config_dict.items() if k in class_fields + } + return cls(**filtered_dict) + + @classmethod + def from_local_dir(cls, lora_path: str, + max_position_embeddings: Optional[int]) -> "PEFTHelper": + lora_config_path = os.path.join(lora_path, "adapter_config.json") + + with open(lora_config_path) as f: + config = json.load(f) + config["vllm_max_position_embeddings"] = max_position_embeddings + return cls.from_dict(config) + + def validate_legal(self, lora_config: LoRAConfig) -> None: + """ + Validates the LoRA configuration settings against application + constraints and requirements. + """ + error_msg = self._validate_features() + if self.r > lora_config.max_lora_rank: + error_msg.append( + f"LoRA rank {self.r} is greater than max_lora_rank" + f" {lora_config.max_lora_rank}.") + if self.bias != "none" and not lora_config.bias_enabled: + error_msg.append( + "Adapter bias cannot be used without bias_enabled.") + if error_msg: + raise ValueError(f"{' '.join(error_msg)}") diff --git a/vllm/lora/punica_wrapper/__init__.py b/vllm/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000..915fc66 --- /dev/null +++ b/vllm/lora/punica_wrapper/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper + +__all__ = [ + "PunicaWrapperBase", + "get_punica_wrapper", +] diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py new file mode 100644 index 0000000..94fa3f2 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -0,0 +1,483 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from .utils import compute_meta, convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +class PunicaWrapperABC(ABC): + """ + PunicaWrapper ABC. + """ + + @abstractmethod + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs, + ) -> None: + """ + Update the lora-related metadata + """ + raise NotImplementedError + + @abstractmethod + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + """ + Performs GEMM for multiple slices of lora_a. + """ + + raise NotImplementedError + + @abstractmethod + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA, + and this layer only requires the expand operation. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + """ + + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + """ + raise NotImplementedError + + +class PunicaWrapperBase(PunicaWrapperABC): + """ + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indices tensors. 
+ # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.device: torch.device = device + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + self.device, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. 
seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLoRA. + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + @abstractmethod + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> None: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. 
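+        For example (illustrative), with output_slices = (q_size, kv_size,
+        kv_size) for a fused QKV projection, slice i of x is expanded by
+        lora_b_stacked[i] and accumulated into its own column range of y.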
+ + Semantics: + offset = offset_start + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + offset_start (int): The starting position of y, defaults to 0 + add_inputs (bool): Defaults to True. + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + and this layer only requires the expand operation. + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. 
+ """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py new file mode 100644 index 0000000..29428f4 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -0,0 +1,348 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Optional, Tuple, Union + +import torch + +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +from .punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperCPU(PunicaWrapperBase): + """ + PunicaWrapperCPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. 
+ Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. 
+ """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py new file mode 100644 index 0000000..2afd3af --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final + +import torch + +import vllm.envs as envs +from vllm.lora.layers import LoRAMapping +from vllm.triton_utils import HAS_TRITON + + +from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, + lora_shrink) + +from .punica_base import PunicaWrapperBase + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.models import LongContextLoRAContext +from vllm import _custom_ops as ops + +@final +class PunicaWrapperGPU(PunicaWrapperBase): + """ + PunicaWrapperGPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica triton kernel. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + self.max_loras = kwargs['max_loras'] + + self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras, + max_num_batched_tokens, + device=device) + + # When cudagraph capture size is greater than max_num_seqs (max_batches, + # here), V0 captures the graph as if max_num_seqs is set to + # the capture size. + # V1 doesn't have this problem and always respects max_num_seqs. + max_num_prompts = (max_batches + if envs.VLLM_USE_V1 else max_num_batched_tokens) + self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras, + max_num_prompts, + device=device) + + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self.is_prefill = mapping.is_prefill + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + + # Prepare cuda kernel metadata tensors + self.token_mapping_meta.prepare_tensors(self.token_lora_indices) + self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) + + # def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + # w_t_all: torch.Tensor, scale: float): + # """ + # Perform the ` y+=x@w_t_all` computation, which is suitable for the + # GEMM of lora'a. + # When `is_prefill is` true, it indicates that it is currently the + # prefill stage, and the `_shrink_prefill` function should be called. + # Otherwise, it is the decode stage, and the _shrink_decode function + # should be called. + # """ + # y_org = y + # y = y.view(-1, y.shape[-1]) + # if self.is_prefill: + # if self.no_lora: + # return y + # y = ops.sbgmv_shrink(x, w_t_all, y, *self.prefill_metadata, scale=scale) + # else: + # y = ops.sbgmv_shrink(x, w_t_all, y, lora_indices_tensor=self.token_lora_indices, scale=scale) + # y = y.view_as(y_org) + # return y + + def add_shrink(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + ...], scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. 
+ + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (torch.Tensor): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + lora_shrink( + x, + lora_a_stacked, + y, + *self.token_mapping_meta.meta_args(x.size(0)), + scale, + ) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + if self.is_prefill: + if self.no_lora: + return y + ops.sbgmv_expand(x, w_t_all, y[:, y_offset:y_offset+y_slice_size], *self.prefill_metadata, add_input=add_inputs) + else: + ops.sbgmv_expand(x, w_t_all, y[:, y_offset:y_offset+y_slice_size], lora_indices_tensor=self.token_lora_indices, add_input=add_inputs) + def add_expand(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + if lora_bias_stacked is not None: + token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, + y.size(0)) + self._apply_bias(token_lora_indices, y, output_slices, + lora_bias_stacked) + + assert x.ndim == 3 + assert x.size(0) == len(output_slices) + num_tokens = x.size(1) # first dimension is the num slices + + lora_expand( + x, + lora_b_stacked, + y, + *self.token_mapping_meta.meta_args(num_tokens), + offset_start=offset_start, + add_inputs=True, + ) + + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + lora_expand( + x.unsqueeze(dim=0), + (lora_b_stacked, ), + y, + *self.token_mapping_meta.meta_args(x.size(0)), + offset_start=0, + add_inputs=add_inputs, + ) + + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. 
+ + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[torch.Tensor]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, + y.size(0)) + y = self._apply_bias(token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros( # type: ignore + (len(output_slices), x.size(0), r), + dtype=torch.float32, + device=x.device, + ) + self.add_shrink( + buffer, # type: ignore + x, + lora_a_stacked, + scale, + **kwargs) + self.add_expand( + y, + buffer, # type: ignore + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor): lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]): Default to None. 
+ """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), + *self.prompt_mapping_meta.meta_args(x.size(0)), scale) + + lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], + y, + *self.prompt_mapping_meta.meta_args(buffer.size(0)), + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py new file mode 100644 index 0000000..3661a72 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_hpu.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final + +import torch +from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, + dispatch_bgmv_linear) + +from .punica_base import PunicaWrapperBase +from .utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +@final +class PunicaWrapperHPU(PunicaWrapperBase): + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + # Increasing max_num_batched_tokens by 3x to handle increase in + # tensor size due to padding. + PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, + max_batches, device) + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size, + extra_vocab_size, self.device, None) + # Updating each element in `long_lora_offsets` with `lora_offset` slows + # down perf in HPU due to a series of `strided_insert` ops during lazy + # graph accumulation. Hence HPU appends `lora_offset` to a list and + # converts it to a tensor only after it is ready. + if long_lora_context: + index_mapping_indices: List[int] = list( + mapping.index_mapping).copy() + long_lora_offsets: List[int] = [] + for i in range(len(index_mapping_indices)): + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets.append(lora_offset) + long_lora_offsets_tensor = torch.tensor(long_lora_offsets, + device=self.device, + dtype=torch.long) + indices_len[-1] = long_lora_offsets_tensor.shape[-1] + + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + dispatch_bgmv_embedding(y, x, lora_b_stacked, 0) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + for slice_idx in range(len(output_slices)): + dispatch_bgmv_linear( + y[:, offset_left:offset_left + output_slices[slice_idx]], x, + lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale) + y = y.view_as(y_org) + + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + raise NotImplementedError + + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py new file mode 100644 index 0000000..ad5d4b7 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import resolve_obj_by_qualname + +from .punica_base import PunicaWrapperBase + +logger = init_logger(__name__) + + +def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: + punica_wrapper_qualname = current_platform.get_punica_wrapper() + punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) + punica_wrapper = punica_wrapper_cls(*args, **kwargs) + assert punica_wrapper is not None, \ + "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong." 
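+    # Illustrative example of the resolution above (an assumption, not taken
+    # from this diff): on a CUDA-like platform,
+    # current_platform.get_punica_wrapper() is expected to return a dotted
+    # path such as "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU",
+    # which resolve_obj_by_qualname() imports and which is then instantiated
+    # with the (max_num_batched_tokens, max_batches, device, ...) arguments
+    # forwarded by the caller (PunicaWrapperGPU additionally consumes a
+    # `max_loras` keyword argument, as seen in punica_gpu.py).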
+    logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
+                     ".")
+    return punica_wrapper
diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py
new file mode 100644
index 0000000..dbc2d27
--- /dev/null
+++ b/vllm/lora/punica_wrapper/utils.py
@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
+
+
+def compute_meta(
+    token_lora_tensor: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
+    """
+    Get the information required for the sgmv kernel. With the features:
+    1. If consecutive requests in the batch use the same LoRA, this function
+    will combine them into a single request, improving sgmv kernel inference
+    performance.
+    2. At the beginning of each prefill stage inference, recalculations are
+    needed based on the input, but only once.
+    """
+
+    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
+        token_lora_tensor, return_counts=True)
+    cum_result = torch.cumsum(seq_length_tensor, dim=0)
+    b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
+    b_seq_start_tensor[1:].copy_(cum_result[:-1])
+    max_length = seq_length_tensor.max().item()
+    token_nums = seq_length_tensor.sum().item()
+    batch_size = lora_indices_tensor.size(0)
+    no_lora = False
+    # -1 means no lora should be applied. Use `no_lora` to determine whether
+    # the current step requires LoRA. If LoRA is not needed, the prefill stage
+    # does not need to launch the triton kernel, which can improve performance
+    if batch_size == 1 and lora_indices_tensor == -1:
+        no_lora = True
+    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
+            batch_size, max_length, token_nums, no_lora)
+
+
+# TODO see if this can be vectorized
+def convert_mapping(
+    mapping: "LoRAMapping",
+    lora_index_to_id: List[Optional[int]],
+    max_loras: int,
+    vocab_size: int,
+    extra_vocab_size: int,
+    device: torch.device,
+    long_lora_context: Optional["LongContextLoRAContext"] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
+           Optional[torch.Tensor], List[int]]:
+    """Converts LoRAMapping to index tensors.
+
+    Args:
+        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
+        lora_index_to_id: List mapping LoRA ids to LoRA indices.
+        max_loras: Maximum number of LoRAs.
+        vocab_size: Model vocab size.
+        extra_vocab_size: Extra vocab size each LoRA can have.
+        long_lora_context: Passed if there are long context LoRAs in a batch.
+
+    Returns:
+        A tuple of tensors:
+            base_indices: Tensor of shape [batch_size] mapping batch rows to
+                LoRA indices.
+            sampler_indices: Tensor of shape [batch_size] mapping requests to
+                LoRA indices for sampler. For generation, this will be the
+                same as base_indices. For prefill, this will map requests
+                to LoRA indices.
+            sampler_indices_padded: Tensor of shape [batch_size] mapping
+                requests to LoRA indices for sampler with padding.
+                Same as sampler_indices, but -1 is replaced with
+                max_loras.
+            embeddings_indices: Tensor of shape [2, batch_size] mapping
+                requests to embedding indices. First row is for embeddings
+                added by the LoRAs, second row is for the LoRA.lora_a
+                embeddings.
+            long_lora_indices: Tensor of shape [batch_size] mapping
+                requests to RoPE offsets and rot dims for long LoRAs.
+                None if long context lora doesn't exist.
+ indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = torch.tensor(prompt_mapping, + dtype=torch.long, + device=device) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000..badfaa4 --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from typing import Optional + +import msgspec + +from vllm.adapter_commons.request import AdapterRequest + + +class LoRARequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """ + Request for a LoRA adapter. + + Note that this class should be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. 
+ + lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + __metaclass__ = AdapterRequest + + lora_name: str + lora_int_id: int + lora_path: str = "" + lora_local_path: Optional[str] = msgspec.field(default=None) + long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) + + def __post_init__(self): + if self.lora_local_path: + warnings.warn( + "The 'lora_local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'lora_path' instead.", + DeprecationWarning, + stacklevel=2) + if not self.lora_path: + self.lora_path = self.lora_local_path or "" + + # Ensure lora_path is not empty + assert self.lora_path, "lora_path cannot be empty" + + @property + def adapter_id(self): + return self.lora_int_id + + @property + def name(self): + return self.lora_name + + @property + def path(self): + return self.lora_path + + @property + def local_path(self): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + return self.lora_path + + @local_path.setter + def local_path(self, value): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + self.lora_path = value + + def __eq__(self, value: object) -> bool: + """ + Overrides the equality method to compare LoRARequest + instances based on lora_name. This allows for identification + and comparison lora adapter across engines. + """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. 
+ """ + return hash(self.lora_name) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000..610cbf8 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import re +from typing import List, Optional, Set, Tuple, Type, Union + +import huggingface_hub +from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, + HFValidationError, RepositoryNotFoundError) +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA) +# being imported for _all_lora_classes below +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLoRA, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +from vllm.model_executor.layers.linear import LinearBase +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLoRA, +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + instance_layer = lora_cls(layer) + instance_layer.create_lora_weights(max_loras, lora_config, + model_config) + return instance_layer + return layer + + +def from_layer_logits_processor( + layer: LogitsProcessor, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device, + lm_head.get_sharded_to_full_mapping()) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def 
parse_fine_tuned_lora_name( + name: str, + weights_mapper: Optional[WeightsMapper] = None +) -> Tuple[str, bool, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + weights_mapper: maps the name of weight, e.g. + `model.` -> `language_model.model.`, + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + is_bias whether the tensor is lora bias. + """ + + # LoRA weight qualified name always starts with `base_model.model.`, + # so we remove the prefix `base_model.model.` to make the following + # mapping correctly. + if "base_model.model." in name: + name = name.replace("base_model.model.", "") + name = weights_mapper._map_name(name) if weights_mapper else name + # recover the prefix `base_model.model.` + name = "base_model.model." + name + + parts = name.split(".") + if parts[-1] == "weight" and (parts[-2] == "lora_A" + or parts[-2] == "lora_B"): + new_name = ".".join(parts[2:-2]) + return new_name, parts[-2] == "lora_A", False + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + new_name = ".".join(parts[2:-1]) + return new_name, parts[-1] == "lora_embedding_A", False + + if parts[-1] == "bias": + new_name = ".".join(parts[2:-2]) + return new_name, False, True + + raise ValueError(f"{name} is unsupported LoRA weight") + + +def is_regex_target_modules(load_modules: Union[str, List[str]], + expected_lora_modules: List[str]) -> bool: + """ + PEFT supports passing `target_modules` in the form of regular expressions, + such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to + determine whether the suffix in the regular expression is present in the + `expected_lora_modules`. + """ + + def is_valid_regex(pattern): + try: + re.compile(pattern) + return True + except re.error: + return False + + def is_subset(sub_list, full_list): + return set(sub_list).issubset(set(full_list)) + + # Similar to PEFT's processing logic, regex-related operations are only + # executed when the load_modules is a `str`. + if not isinstance(load_modules, str): + return False + + if is_valid_regex(load_modules): + match = re.search(r"\((.*?)\)\$?$", load_modules) + if match: + suffix = match.group(1).split("|") + return is_subset(suffix, expected_lora_modules) + return False + + +def get_supported_lora_modules(model: nn.Module) -> List[str]: + """ + In vLLM, all linear layers support LoRA. + """ + supported_lora_modules: Set[str] = set() + # step1: traverse the model to get all the linear subfixes. + for name, module in model.named_modules(): + if isinstance(module, (LinearBase, )): + supported_lora_modules.add(name.split(".")[-1]) + # step 2: get the embedding modules if the model's mbedding_modules + # is not empty. + if model.embedding_modules: + for name in model.embedding_modules: + supported_lora_modules.add(name) + return list(supported_lora_modules) + + +def get_adapter_absolute_path(lora_path: str) -> str: + """ + Resolves the given lora_path to an absolute local path. + + If the lora_path is identified as a Hugging Face model identifier, + it will download the model and return the local snapshot path. + Otherwise, it treats the lora_path as a local file path and + converts it to an absolute path. + + Parameters: + lora_path (str): The path to the lora model, which can be an absolute path, + a relative path, or a Hugging Face model identifier. 
+ + Returns: + str: The resolved absolute local path to the lora model. + """ + + # Check if the path is an absolute path. Return it no matter exists or not. + if os.path.isabs(lora_path): + return lora_path + + # If the path starts with ~, expand the user home directory. + if lora_path.startswith('~'): + return os.path.expanduser(lora_path) + + # Check if the expanded relative path exists locally. + if os.path.exists(lora_path): + return os.path.abspath(lora_path) + + # If the path does not exist locally, assume it's a Hugging Face repo. + try: + local_snapshot_path = huggingface_hub.snapshot_download( + repo_id=lora_path) + except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, + HFValidationError): + # Handle errors that may occur during the download + # Return original path instead instead of throwing error here + logger.exception("Error downloading the HuggingFace model") + return lora_path + + return local_snapshot_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000..108beb3 --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.models import (LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path + +logger = init_logger(__name__) + + +class WorkerLoRAManager(AbstractWorkerManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + embedding_modules: Dict[str, str], + embedding_padding_modules: List[str], + lora_model_cls: Type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, + ): + self._lora_model_cls = lora_model_cls + self.embedding_modules = embedding_modules + self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) + # Lazily initialized by create_lora_manager. 
+ self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + lora_manager_cls=self._manager_cls, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: + try: + supported_lora_modules = ( + self._adapter_manager.supported_lora_modules) + packed_modules_mapping = ( + self._adapter_manager.packed_modules_mapping) + expected_lora_modules: List[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + + expected_lora_modules = list(set(expected_lora_modules)) + lora_path = get_adapter_absolute_path(lora_request.lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, self.max_position_embeddings) + + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. + model = self._adapter_manager.model + hf_to_vllm_mapper = None + if (hasattr(model, "hf_to_vllm_mapper") + and model.hf_to_vllm_mapper is not None): + hf_to_vllm_mapper = model.hf_to_vllm_mapper + + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper) + + except FileNotFoundError as e: + # FileNotFoundError should be raised if both + # - No adapter found to download from huggingface (or in + # offline mode) + # - No local adapter files found at `lora_request.lora_path` + # For NotFoundError + raise ValueError( + f"Loading lora {lora_request.lora_name} failed: No adapter " + f"found for {lora_request.lora_path}") from e + except Exception as e: + # For BadRequestError + raise e + + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " + f"is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}.") + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_adapters(): + return False + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._adapter_manager.create_dummy_lora( + lora_request.lora_int_id, rank, 1, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._adapter_manager.add_adapter(dummy_lora) + + def 
pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + lora_manager_cls=self._manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._adapter_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._adapter_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_adapter(lora) + + def add_adapter(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_adapters(): + # Load the new adapter first to ensure it is actually valid, before + # evicting any existing adapters. + # This may cause the # of loaded lora adapters to very temporarily + # exceed `--max-cpu-loras`. 
+ lora = self._load_adapter(lora_request) + + # Loading succeeded, now check if we will exceed cache capacity and + # evict if the oldest adapter if so + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + # Then add the new adapter to the cache + loaded = self._adapter_manager.add_adapter(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + lora_request.lora_int_id) is not None + self._adapter_manager.activate_adapter(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py new file mode 100644 index 0000000..7636152 --- /dev/null +++ b/vllm/model_executor/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.model_executor.parameter import (BasevLLMParameter, + PackedvLLMParameter) +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingMetadataCache) +from vllm.model_executor.utils import set_random_seed + +__all__ = [ + "SamplingMetadata", + "SamplingMetadataCache", + "set_random_seed", + "BasevLLMParameter", + "PackedvLLMParameter", +] diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py new file mode 100644 index 0000000..dfd052f --- /dev/null +++ b/vllm/model_executor/custom_op.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Type + +import torch.nn as nn + +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class CustomOp(nn.Module): + """ + Base class for custom ops. + Dispatches the forward method to the appropriate backend. + """ + + def __init__(self): + super().__init__() + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + """PyTorch-native implementation of the forward method. + This method is optional. If implemented, it can be used with compilers + such as torch.compile or PyTorch XLA. Also, it can be used for testing + purposes. + """ + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs): + raise NotImplementedError + + def forward_hip(self, *args, **kwargs): + # By default, we assume that HIP ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_xpu(self, *args, **kwargs): + # By default, we assume that XPU ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def forward_cpu(self, *args, **kwargs): + # By default, we assume that CPU ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_tpu(self, *args, **kwargs): + # By default, we assume that TPU ops are compatible with the + # PyTorch-native implementation. + # NOTE(woosuk): This is a placeholder for future extensions. + return self.forward_native(*args, **kwargs) + + def forward_hpu(self, *args, **kwargs): + # By default, we assume that Gaudi ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def forward_neuron(self, *args, **kwargs): + # By default, we assume that Neuron ops are compatible with the + # PyTorch-native implementation. 
+ return self.forward_native(*args, **kwargs) + + def forward_oot(self, *args, **kwargs): + # By default, we assume that OOT ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def dispatch_forward(self): + # NOTE(woosuk): Here we assume that vLLM was built for only one + # specific backend. Currently, we do not support dynamic dispatching. + compilation_config = get_current_vllm_config().compilation_config + enabled = self.enabled() + if enabled: + compilation_config.enabled_custom_ops.update([self.__class__.name]) + else: + compilation_config.disabled_custom_ops.update( + [self.__class__.name]) + + if not enabled: + return self.forward_native + + if current_platform.is_rocm(): + return self.forward_hip + elif current_platform.is_cpu(): + return self.forward_cpu + elif current_platform.is_hpu(): + return self.forward_hpu + elif current_platform.is_tpu(): + return self.forward_tpu + elif current_platform.is_xpu(): + return self.forward_xpu + elif current_platform.is_neuron(): + return self.forward_neuron + elif current_platform.is_out_of_tree(): + return self.forward_oot + else: + return self.forward_cuda + + @classmethod + def enabled(cls) -> bool: + # if no name, then it was not registered + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops + if not hasattr(cls, "name"): + logger.warning_once( + f"Custom op {cls.__name__} was not registered, " + f"which means it won't appear in the op registry. " + f"It will be enabled/disabled based on the global settings.") + return CustomOp.default_on() + + enabled = f"+{cls.name}" in custom_ops + disabled = f"-{cls.name}" in custom_ops + assert not (enabled + and disabled), f"Cannot enable and disable {cls.name}" + + return (CustomOp.default_on() or enabled) and not disabled + + @staticmethod + def default_on() -> bool: + """ + On by default if level < CompilationLevel.PIECEWISE + Specifying 'all' or 'none' in custom_op takes precedence. + """ + from vllm.config import CompilationLevel + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops + count_none = custom_ops.count("none") + count_all = custom_ops.count("all") + return compilation_config.level < CompilationLevel.PIECEWISE and \ + not count_none > 0 or count_all > 0 + + # Dictionary of all custom ops (classes, indexed by registered name). + # To check if an op with a name is enabled, call .enabled() on the class. + # Examples: + # - MyOp.enabled() + # - op_registry["my_op"].enabled() + op_registry: Dict[str, Type['CustomOp']] = {} + + # Decorator to register custom ops. 
+ @classmethod + def register(cls, name: str): + + def decorator(op_cls): + assert name not in cls.op_registry, f"Duplicate op name: {name}" + op_cls.name = name + cls.op_registry[name] = op_cls + return op_cls + + return decorator diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 0000000..d4fd11f --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.utils import ( + convert_lark_to_gbnf, grammar_is_likely_lark, + has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) +from vllm.reasoning import ReasoningParserManager + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.logits_process import LogitsProcessor + from vllm.sampling_params import GuidedDecodingParams + +logger = init_logger(__name__) + + +def maybe_backend_fallback( + guided_params: GuidedDecodingParams) -> GuidedDecodingParams: + + def fallback_or_error(guided_params: GuidedDecodingParams, message: str, + fallback: str) -> None: + """Change the backend to the specified fallback with a warning log, + or raise a ValueError if the `no-fallback` option is specified.""" + if guided_params.no_fallback(): + raise ValueError(message) + + logger.warning("%s Falling back to use %s instead.", message, fallback) + guided_params.backend = fallback + + # lm-format-enforce doesn't support grammar, fallback to xgrammar + if guided_params.backend_name == "lm-format-enforcer": + if guided_params.grammar is not None: + fallback_or_error( + guided_params, + "lm-format-enforcer does not support grammar guided decoding.", + "xgrammar") + + # lm-format-enforcer doesn't support some JSON schema features + elif (guided_params.json is not None + and has_lmf_unsupported_json_features(guided_params.json)): + fallback_or_error( + guided_params, + "lm-format-enforcer does not support advanced JSON schema " + "features like patterns or numeric ranges.", "outlines") + + if guided_params.backend_name == "xgrammar": + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( + xgr_installed) + + # xgrammar doesn't support regex, fallback to outlines + if guided_params.regex is not None: + fallback_or_error( + guided_params, + "xgrammar does not support regex guided decoding.", "outlines") + # xgrammar doesn't support some JSON schema features + elif (guided_params.json is not None + and has_xgrammar_unsupported_json_features(guided_params.json)): + fallback_or_error( + guided_params, + "xgrammar does not support advanced JSON schema features like " + "enums, patterns or numeric ranges.", "outlines") + + # xgrammar only supports GBNF grammars, so we must convert Lark. + # We must check if the grammar is likely Lark and if that + # grammar is convertible to GBNF + elif (guided_params.grammar is not None + and grammar_is_likely_lark(guided_params.grammar)): + try: + convert_lark_to_gbnf(guided_params.grammar) + except Exception: + fallback_or_error( + guided_params, + "xgrammar does not support Lark grammars and the " + "grammar failed to convert to GBNF.", "outlines") + + # If the xgrammar module cannot be imported successfully, + # we should still allow users to use guided decoding with a fallback. 
+ elif not xgr_installed: + fallback_or_error( + guided_params, + "xgrammar module cannot be imported successfully.", "outlines") + + if (guided_params.backend_name == "outlines" + and guided_params.json_object is not None): + # outlines doesn't support json_object, fallback to guidance + fallback_or_error(guided_params, + "outlines does not support json_object.", "guidance") + + return guided_params + + +async def get_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoning_backend: str | None = None) -> LogitsProcessor | None: + + reasoner = None + if reasoning_backend is not None: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) + + guided_params = maybe_backend_fallback(guided_params) + + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend_name == 'outlines': + # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 + from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa + get_outlines_guided_decoding_logits_processor) + return await get_outlines_guided_decoding_logits_processor( + guided_params, tokenizer, reasoner) + if guided_params.backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) + if guided_params.backend_name == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config, reasoner) + if guided_params.backend_name == 'guidance': + from vllm.model_executor.guided_decoding.guidance_decoding import ( + get_local_guidance_guided_decoding_logits_processor) + return get_local_guidance_guided_decoding_logits_processor( + guided_params, tokenizer) + raise ValueError( + f"Unknown guided decoding backend '{guided_params.backend}'. 
" + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'" + ) + + +def get_local_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoning_backend: str | None = None) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) + + reasoner = None + if reasoning_backend is not None: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) + + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend_name == 'outlines': + # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 + from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa + get_local_outlines_guided_decoding_logits_processor) + return get_local_outlines_guided_decoding_logits_processor( + guided_params, tokenizer, reasoner) + if guided_params.backend_name == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) + if guided_params.backend_name == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config, reasoner) + if guided_params.backend_name == 'guidance': + from vllm.model_executor.guided_decoding.guidance_decoding import ( + get_local_guidance_guided_decoding_logits_processor) + return get_local_guidance_guided_decoding_logits_processor( + guided_params, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_params.backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'" + ) diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py new file mode 100644 index 0000000..f19ebcb --- /dev/null +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +from re import escape as regex_escape + +import llguidance +from transformers import PreTrainedTokenizerBase + +from vllm.model_executor.guided_decoding.guidance_logits_processors import ( + GuidanceLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams + + +def get_local_guidance_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. 
+ """ + + grm = "" + any_whitespace = 'disable-any-whitespace' not in \ + guided_params.backend_options() + if guided_params.json: + grm = llguidance.LLMatcher.grammar_from_json_schema( + guided_params.json, + overrides={"whitespace_pattern": guided_params.whitespace_pattern}, + defaults={ + "whitespace_flexible": any_whitespace, + }) + elif guided_params.json_object: + grm = llguidance.LLMatcher.grammar_from_json_schema( + '{"type": "object"}', + overrides={"whitespace_pattern": guided_params.whitespace_pattern}, + defaults={ + "whitespace_flexible": any_whitespace, + }) + elif guided_params.regex: + grm = llguidance.grammar_from("regex", guided_params.regex) + elif guided_params.choice: + # choice just uses regex + choices = (regex_escape(str(choice)) + for choice in guided_params.choice) + choices_regex = "(" + "|".join(choices) + ")" + grm = llguidance.grammar_from("regex", choices_regex) + elif guided_params.grammar: + # this supports Lark and GBNF + grm = llguidance.grammar_from("grammar", guided_params.grammar) + + if grm: + return GuidanceLogitsProcessor(grm, tokenizer) + + raise ValueError("Unknown guided decoding mode") diff --git a/vllm/model_executor/guided_decoding/guidance_logits_processors.py b/vllm/model_executor/guided_decoding/guidance_logits_processors.py new file mode 100644 index 0000000..26fcafe --- /dev/null +++ b/vllm/model_executor/guided_decoding/guidance_logits_processors.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Any, List + +import llguidance +import llguidance.hf +import llguidance.torch +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class GuidanceLogitsProcessor: + """Base Guidance Logits Processor""" + + cached_tokenizers: dict[str, Any] = {} + + def __init__( + self, + grammar: str, + tokenizer: PreTrainedTokenizerBase, + ) -> None: + """Base Guidance Logits Processor + + Args: + grammar (str) + grammar to guide the generation + tokenizer (PreTrainedTokenizerBase) + model's tokenizer + """ + self.grammar = grammar + self.tokenizer = tokenizer + self.tokenizer_name = tokenizer.name_or_path + self.new_sampling = False + self.initialized = False + + def _initialize(self): + if self.initialized: + return + + ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path, + None) + if ll_tokenizer is None: + ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) + self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer + + self.ll_tokenizer = ll_tokenizer + self.ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, + self.grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + + # create reusable bitmask + self.bitmask = llguidance.torch.allocate_token_bitmask( + 1, self.ll_tokenizer.vocab_size) + + self.initialized = True + + def __call__( + self, + input_ids: List[int], + scores: torch.Tensor, + ) -> torch.Tensor: + # we initialize the guidance model here + # to avoid pickling ll_tokenizer and ll_interpreter + self._initialize() + + if self.new_sampling and len(input_ids) > 0: + self.ll_matcher.consume_token(input_ids[-1]) + err = self.ll_matcher.get_error() + if err: + logger.warning("Error in LLMatcher: %s", err) + + llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, + 0) + llguidance.torch.apply_token_bitmask_inplace( + scores, self.bitmask.to(scores.device)) + + self.new_sampling = True + + return scores diff --git 
a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py new file mode 100644 index 0000000..db4ce26 --- /dev/null +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Dict, List, Optional, TypedDict, Union + +from pydantic import BaseModel + + +# These classes are deprecated, see SamplingParams +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[Dict, BaseModel, str] + guided_regex: str + guided_choice: List[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + +@dataclass +class GuidedDecodingRequest: + """One of the fields will be used to retrieve the logit processor.""" + guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + guided_decoding_backend: Optional[str] = None + guided_whitespace_pattern: Optional[str] = None + guided_json_object: Optional[bool] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.guided_json is not None, self.guided_regex is not None, + self.guided_choice is not None, self.guided_grammar is not None, + self.guided_json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 0000000..7eaf9e3 --- /dev/null +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from transformers import PreTrainedTokenizerBase + +from vllm.logits_process import LogitsProcessor +from vllm.sampling_params import GuidedDecodingParams + + +def get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. 
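+
+    Illustrative usage (example values only; assumes a HuggingFace tokenizer
+    is already loaded and the lm-format-enforcer package is installed):
+
+        params = GuidedDecodingParams(regex=r"[0-9]+")
+        lp = get_local_lm_format_enforcer_guided_decoding_logits_processor(
+            params, tokenizer)
+
+    Returns None when the request carries no guided decoding field; the
+    grammar option is not supported by this backend and raises ValueError.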
+ """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if guided_params.json: + schema_dict = _normalize_json_schema_object(guided_params.json) + character_level_parser = JsonSchemaParser(schema_dict) + elif guided_params.choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in guided_params.choice]) + elif guided_params.regex: + character_level_parser = RegexParser(guided_params.regex) + elif guided_params.grammar: + # CFG grammar not supported by LMFE + raise ValueError("Cannot construct a guided decoding logits processor" + " using the grammar option with the" + " lm_format_enforcer backend.") + elif guided_params.json_object: + # None means any json object + character_level_parser = JsonSchemaParser(None) + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + raise AssertionError(f"Unsupported schema type {schema}") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py new file mode 100644 index 0000000..564f927 --- /dev/null +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import concurrent.futures +import os +from enum import Enum +from json import dumps as json_dumps +from re import escape as regex_escape +from typing import Optional, Tuple, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.reasoning import ReasoningParser +from vllm.sampling_params import GuidedDecodingParams + + +class GuidedDecodingMode(Enum): + JSON = "json" + REGEX = "regex" + CHOICE = "choice" + GRAMMAR = "grammar" + + +# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark +# the main difference is that we changed the start: value to +# start: object | array, so we are denying scalar values as the root of the +# JSON. Starting with scalars as the root seems to cause llama to generate +# without stop. +JSON_GRAMMAR = r""" +?start: object | array + +?value: object +| array +| UNESCAPED_STRING +| SIGNED_NUMBER -> number +| "true" -> true +| "false" -> false +| "null" -> null + +array : "[" [value ("," value)*] "]" +object : "{" [pair ("," pair)*] "}" +pair : UNESCAPED_STRING ":" value + +%import common.UNESCAPED_STRING +%import common.SIGNED_NUMBER +%import common.WS + +%ignore WS +""" + +global_thread_pool = None # used for generating logits processor fsm + +# It's not yet clear that using more provides a benefit, and it could +# potentially starve other processes on the machine. We'll cap this for now and +# adjust later if testing proves it to help overcome a bottleneck. 
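+# The pool itself is created lazily on the first async request and is then
+# shared for the lifetime of the process.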
+_MAX_THREADPOOL_WORKERS = 16 + + +async def get_outlines_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + global global_thread_pool + guide, mode = _get_guide_and_mode(guided_params) + if not guide or not mode: + return None + + if global_thread_pool is None: + max_workers = os.cpu_count() or 2 + if max_workers > _MAX_THREADPOOL_WORKERS: + max_workers = _MAX_THREADPOOL_WORKERS + global_thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) + loop = asyncio.get_running_loop() + + return await loop.run_in_executor(global_thread_pool, + _get_logits_processor, guide, tokenizer, + mode, guided_params.whitespace_pattern, + reasoner) + + +def get_local_outlines_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + guide, mode = _get_guide_and_mode(guided_params) + if not guide or not mode: + return None + + return _get_logits_processor(guide, tokenizer, mode, + guided_params.whitespace_pattern, reasoner) + + +def _get_guide_and_mode( + guided_params: GuidedDecodingParams +) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: + if guided_params.json: + if isinstance(guided_params.json, dict): + # turn dict into hashable string + json = json_dumps(guided_params.json) + else: + json = guided_params.json + return json, GuidedDecodingMode.JSON + elif guided_params.regex: + return guided_params.regex, GuidedDecodingMode.REGEX + elif guided_params.choice: + # choice just uses regex + choices = [ + regex_escape(str(choice)) for choice in guided_params.choice + ] + choices_regex = "(" + "|".join(choices) + ")" + return choices_regex, GuidedDecodingMode.CHOICE + elif guided_params.grammar: + return guided_params.grammar, GuidedDecodingMode.GRAMMAR + elif guided_params.json_object: + return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR + else: + return None, None + + +def _get_logits_processor( + guide: str, + tokenizer: PreTrainedTokenizerBase, + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None], + reasoner: Optional[ReasoningParser], +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: + if mode == GuidedDecodingMode.JSON: + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, + reasoner) + elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: + return RegexLogitsProcessor(guide, tokenizer, reasoner) + elif mode == GuidedDecodingMode.GRAMMAR: + return CFGLogitsProcessor(guide, tokenizer, reasoner) + else: + raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py 
b/vllm/model_executor/guided_decoding/outlines_logits_processors.py new file mode 100644 index 0000000..31af459 --- /dev/null +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024- the Outlines developers +# This file is adapted from +# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import json +from collections import defaultdict +from functools import lru_cache +from typing import Callable, DefaultDict, Dict, List, Optional, Union + +import numpy as np +import torch +from outlines import grammars +from outlines.caching import cache, disable_cache +from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide, + RegexGuide, Write) +from outlines.fsm.parsing import PartialLark +from outlines_core.fsm.json_schema import build_regex_from_schema +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.reasoning import ReasoningParser + +logger = init_logger(__name__) + +if envs.VLLM_V0_USE_OUTLINES_CACHE: + logger.warning("Enabling outlines cache. This is an unbounded on-disk " + "cache. It may consume a lot of disk space and should " + "not be used with untrusted clients.") +else: + disable_cache() + + +class BaseLogitsProcessor: + + def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]): + self._guide: Guide = guide + self._reasoner: Optional[ReasoningParser] = reasoner + # CFGState is used for the FSM state for CFGGuide + self._fsm_state: DefaultDict[int, Union[int, + CFGState]] = defaultdict(int) + + def __call__(self, input_ids: List[int], + scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token.""" + + # Skip the structured logits processing if reasoning is not finished. + # reasoner is not None only when `--enable-reasoning` is set. + if self._reasoner is not None: + if not self._reasoner.is_reasoning_end(input_ids): + return scores + else: + # Remove the reasoning tokens from the input_ids + # We need this because our implementation relies on the + # hash of the input_ids to store the FSM state. + input_ids = self._reasoner.extract_content_ids(input_ids) + + seq_id = hash(tuple(input_ids)) + + if len(input_ids) > 0: + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + self._fsm_state[seq_id] = self._guide.get_next_state( + state=self._fsm_state[last_seq_id], token_id=last_token) + else: + # Note: this is a hack. + # Lark pickling does not work properly (silent failure), + # which breaks the RPC (which uses python pickleing). + # We need to find a better solution. + # On the first time this is called, we simply re-create + # the Lark object. 
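+            # For CFG guides, a fresh parser is paired with an empty CFGState;
+            # other guide types simply start from the defaultdict's initial
+            # state (0) and are advanced via get_next_state above.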
+ if isinstance(self._guide, CFGGuide): + self._guide.parser = PartialLark( + self._guide.cfg_string, + parser="lalr", + import_paths=[grammars.GRAMMAR_PATH], + ) + self._fsm_state[seq_id] = CFGState( + parser_state=self._guide.parser.parse(""), prev_token=None) + + instruction = self._guide.get_next_instruction( + state=self._fsm_state[seq_id]) + + if type(instruction) == Generate: # noqa: E721 + allowed_tokens = instruction.tokens + elif type(instruction) == Write: # noqa: E721 + # TODO: support fast forward tokens + allowed_tokens = [instruction.tokens[0]] + else: + raise TypeError( + f"Unsupported instruction type {type(instruction)}") + + mask = torch.full((scores.shape[-1], ), + -torch.inf, + device=scores.device) + # The tokenizer may support more token ids than the model can generate, + # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256 + # but scores.shape == torch.Size([128256]) + # Using NumPy is faster for filtering token ids + allowed_tokens = np.array(allowed_tokens, dtype=np.int64) + allowed_tokens = torch.tensor(allowed_tokens, device=scores.device) + allowed_tokens = allowed_tokens.masked_select( + allowed_tokens < scores.shape[-1]) + mask.index_fill_(0, allowed_tokens, 0) + if current_platform.is_hpu(): + # Workaround for HPU bug where add_() raise RuntimeError: + # synNodeCreateWithId failed for node: strided_insert + # with synStatus 1 [Invalid argument], hopefully it will + # be fixed in the future releases of the HPU runtime. + scores = scores.add(mask) + else: + scores.add_(mask) + return scores + + +class RegexLogitsProcessor(BaseLogitsProcessor): + + @classmethod + @cache() + def _get_guide(cls, regex_string: str, + tokenizer: PreTrainedTokenizerBase) -> Guide: + tokenizer = _adapt_tokenizer(tokenizer) + return RegexGuide.from_regex(regex_string, tokenizer) + + def __init__( + self, + regex_string: str, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], + ): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + The model's tokenizer + + """ + super().__init__( + RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner) + + +class JSONLogitsProcessor(RegexLogitsProcessor): + + def __init__(self, schema: Union[str, Dict, BaseModel], + tokenizer: PreTrainedTokenizerBase, + whitespace_pattern: Union[str, None], + reasoner: Optional[ReasoningParser]): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to + generate + tokenizer + The model's tokenizer + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact + string literals) + Example: allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + if isinstance(schema, type(BaseModel)): + schema_str = json.dumps(schema.model_json_schema()) + elif isinstance(schema, Dict): + schema_str = json.dumps(schema) + elif isinstance(schema, str): + schema_str = schema + else: + raise ValueError( + f"Cannot parse schema {schema}. 
The schema must be either " + f"a Pydantic object, a dictionary or a string that contains " + f"the JSON Schema specification") + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string, tokenizer, reasoner) + + +class CFGLogitsProcessor(BaseLogitsProcessor): + + @classmethod + @cache() + def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: + tokenizer = _adapt_tokenizer(tokenizer) + return CFGGuide(cfg, tokenizer) + + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser]): + """Compile the FSM that drives the context free grammar generation. + + Parameters + ---------- + cfg + A string that represents a context-free grammar + tokenizer + The model's tokenizer + + """ + super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer), + reasoner) + self._guide = self._guide.copy() + + +@lru_cache(maxsize=32) +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. + + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if (type(token) is str and token.startswith(SPIECE_UNDERLINE) + or token == "<0x20>"): + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], + str]) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + if (isinstance(inp_tokens, list) and len(inp_tokens) == 1 + and isinstance(inp_tokens[0], list)): + inp_tokens = inp_tokens[0] + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py new file mode 100644 index 0000000..ab6e47c --- /dev/null +++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from transformers import PreTrainedTokenizer + +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501 + DeepSeekReasoner) +from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner + +logger = init_logger(__name__) + + +def get_reasoner(tokenizer: PreTrainedTokenizer, + reasoning_backend: str | None) -> Reasoner | None: + if reasoning_backend is None: + # No reasoning backend specified + return None + elif reasoning_backend == "deepseek_r1": + return 
DeepSeekReasoner.from_tokenizer(tokenizer) + elif reasoning_backend == "granite": + logger.warning( + "Granite reasoner not yet implemented for structured outputs") + return None + else: + # Raise a warning for unknown reasoning backend and return None + # We cannot raise an error here because some reasoning models + # may not have a corresponding Reasoner class. + logger.warning("Unknown reasoning backend %s for structured outputs ", + reasoning_backend) + return None + + +__all__ = ["Reasoner", "get_reasoner"] diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py new file mode 100644 index 0000000..1098177 --- /dev/null +++ b/vllm/model_executor/guided_decoding/utils.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re + + +def has_xgrammar_unsupported_json_features(schema: dict) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for enum restrictions + if "enum" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj for key in [ + "minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf" + ]): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any(key in obj for key in [ + "uniqueItems", "contains", "minContains", "maxContains", + "minItems", "maxItems" + ]): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and any( + key in obj for key in ["minLength", "maxLength", "format"]): + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any(key in obj for key in [ + "minProperties", "maxProperties", "propertyNames", + "patternProperties" + ]): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def has_lmf_unsupported_json_features(schema: dict) -> bool: + """ + Check if JSON schema contains features unsupported + by lm_format_enforcer. + + Known issues: + - Regex patterns: + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + """ + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. 
+ + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for GBNF rule definition + if '::=' in line: + return False + + return True + + +def convert_lark_to_gbnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to GBNF format. + + GBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in GBNF format + + Examples: + >>> print(convert_lark_to_gbnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py new file mode 100644 index 0000000..b44301f --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -0,0 +1,409 @@ +# SPDX-License-Identifier: Apache-2.0 + +# noqa: UP007 +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, List + +import torch + +from vllm.logger import init_logger + +try: + import xgrammar as xgr + xgr_installed = True +except ImportError: + xgr_installed = False + pass + +from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, + grammar_is_likely_lark) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.reasoning import ReasoningParser + from vllm.sampling_params import GuidedDecodingParams + +logger = init_logger(__name__) + + +def get_local_xgrammar_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoner: ReasoningParser | None, + max_threads: int = 8): + config = GrammarConfig.from_guided_params(guided_params=guided_params, + model_config=model_config, + tokenizer=tokenizer, + max_threads=max_threads) + return XGrammarLogitsProcessor(config, reasoner) + + +@dataclass(frozen=True) +class TokenizerData: + """Immutable container for cached tokenizer data.""" + metadata: str + encoded_vocab: list[str] = field(default_factory=list) + + +class TokenizerDataCache: + """Cache manager for tokenizer data to avoid repeated processing.""" + 
_cache: dict[int, TokenizerData] = {} + + @classmethod + def get_tokenizer_data( + cls, + tokenizer: PreTrainedTokenizer, + /, + *, + tokenizer_hash: int, + vocab_size: int, + ) -> TokenizerData: + + if tokenizer_hash not in cls._cache: + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, + # NOTE: We will need to use lm_head's vocab_size + # to determine correct special_token_ids for this tokenizer. + # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501 + vocab_size=vocab_size, + ) + metadata = json.loads(tokenizer_info.dump_metadata()) + + # Vendored from xgrammar logic to get encoded_vocab + # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501 + try: + vocab_dict = tokenizer.get_vocab() + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. The tokenizer should have a " + "get_vocab method.") from e + + # maintain tokenizer's indexing + encoded_vocab = [""] * tokenizer_info.vocab_size + for token, idx in vocab_dict.items(): + if idx < tokenizer_info.vocab_size: + encoded_vocab[idx] = token + + if isinstance(tokenizer, MistralTokenizer): + # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + metadata.update({ + "vocab_type": xgr.VocabType.BYTE_FALLBACK, + "add_prefix_space": True + }) + + cls._cache[tokenizer_hash] = TokenizerData( + encoded_vocab=encoded_vocab, + metadata=json.dumps(metadata), + ) + + return cls._cache[tokenizer_hash] + + +class GrammarCompilerCache: + """ + Cache for GrammarCompiler instances based on tokenizer. + + This cache reduces the overhead of creating new compiler instances when + using the same tokenizer configuration. + """ + _cache: dict[str, xgr.GrammarCompiler] = {} + + @classmethod + def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: + cache_key = str(config.tokenizer_hash) + + if cache_key not in cls._cache: + config_data = config.tokenizer_data + + # In TokenizerDataCache.get_tokenizer_data, a serializable + # tokenizer_data is created and cached. This data is used to build + # a tokenizer_info and create an xgrammar compiler. 
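+            # max_threads bounds xgrammar's compilation parallelism; the
+            # resulting compiler is cached per tokenizer hash.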
+ tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata( + encoded_vocab=config_data.encoded_vocab, + metadata=config_data.metadata, + ) + cls._cache[cache_key] = xgr.GrammarCompiler( + tokenizer_info, max_threads=config.max_threads) + + return cls._cache[cache_key] + + +@dataclass +class GrammarConfig: + """Serializable configuration for grammar compilation""" + tokenizer_hash: int + tokenizer_data: TokenizerData + json_str: str | None = None + grammar_str: str | None = None + json_object: bool | None = None + any_whitespace: bool = True + max_threads: int = 8 + + @classmethod + def from_guided_params(cls, + guided_params: GuidedDecodingParams, + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, + max_threads: int = 8) -> GrammarConfig: + + tokenizer_hash = hash(tokenizer) + tokenizer_data = TokenizerDataCache.get_tokenizer_data( + tokenizer, + tokenizer_hash=tokenizer_hash, + vocab_size=model_config.hf_text_config.vocab_size, + ) + + if guided_params.json: + if not isinstance(guided_params.json, str): + json_str = json.dumps(guided_params.json) + else: + json_str = guided_params.json + + any_whitespace = 'disable-any-whitespace' not in \ + guided_params.backend_options() + + # Check and log if model with xgrammar and whitespace have history + # of runaway generation of whitespaces. + # References: + # https://github.com/vllm-project/vllm/pull/12744 + # https://github.com/mlc-ai/xgrammar/issues/212 + model_with_warn = None + + if 'Mistral' in model_config.model: + model_with_warn = 'Mistral' + elif 'Qwen' in model_config.model: + model_with_warn = 'Qwen' + + if model_with_warn is not None and any_whitespace: + msg = (f"{model_with_warn} " + f"model detected, consider set " + f"`guided_backend=xgrammar:disable-any-whitespace` " + f"to prevent runaway generation of whitespaces.") + logger.info_once(msg) + # Validate the schema and raise ValueError here if it is invalid. + # This is to avoid exceptions in model execution, which will crash + # the engine worker process. + try: + xgr.Grammar.from_json_schema(json_str, + any_whitespace=any_whitespace) + except RuntimeError as err: + raise ValueError(str(err)) from err + + return cls(json_str=json_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + any_whitespace=any_whitespace) + elif guided_params.grammar: + # XGrammar only supports GBNF grammars, so we must convert Lark + if grammar_is_likely_lark(guided_params.grammar): + try: + grammar_str = convert_lark_to_gbnf(guided_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to GBNF. " + "Please either use GBNF grammar directly or specify" + " --guided-decoding-backend=outlines.\n" + f"Conversion error: {str(e)}") from e + else: + grammar_str = guided_params.grammar + + # Validate the grammar and raise ValueError here if it is invalid. + # This is to avoid exceptions in model execution, which will crash + # the engine worker process. 
+ try: + xgr.Grammar.from_ebnf(grammar_str) + except RuntimeError as err: + raise ValueError(str(err)) from err + + return cls(grammar_str=grammar_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data) + elif guided_params.json_object: + return cls( + json_object=True, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) + elif guided_params.choice: + choice_str = GrammarConfig.choice_as_grammar(guided_params.choice) + try: + xgr.Grammar.from_ebnf(choice_str) + except RuntimeError as err: + raise ValueError(str(err)) from err + + return cls( + grammar_str=choice_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) + else: + raise ValueError( + "Currently only support JSON and EBNF grammar mode for xgrammar" + ) + + @staticmethod + def escape_ebnf_string(s: str) -> str: + """Escape special characters in a EBNF string.""" + # Escape double quotes and backslashes + return re.sub(r'(["\\])', r'\\\1', s) + + @staticmethod + def choice_as_grammar(choice: List[str] | None) -> str: + if choice is None: + raise ValueError("Choice is not set") + escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice) + grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + return grammar + + @staticmethod + def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo: + return xgr.TokenizerInfo.from_vocab_and_metadata( + encoded_vocab=tokenizer_data.encoded_vocab, + metadata=tokenizer_data.metadata, + ) + + +@dataclass +class XGrammarLogitsProcessor: + """Wrapper class to support pickle protocol""" + config: GrammarConfig + reasoner: ReasoningParser | None = None + + ctx: xgr.CompiledGrammar | None = None + tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] + token_bitmask: torch.Tensor = None # type: ignore[assignment] + matchers: list[xgr.GrammarMatcher] = field(default_factory=list) + batch_size: int = field(default=1) + prefilled: bool = field(default=False) + + def __post_init__(self): + self.tokenizer_info = self.config.tokenizer_info( + self.config.tokenizer_data) + + def __getstate__(self) -> dict[str, Any]: + return {'config': self.config, 'reasoner': self.reasoner} + + def __setstate__(self, state: dict[str, Any]): + self.config = state['config'] + self.reasoner = state['reasoner'] + + self.tokenizer_info = GrammarConfig.tokenizer_info( + self.config.tokenizer_data) + self.ctx = None + self.matchers = [] + self.batch_size = 1 + self.token_bitmask = None # type: ignore[assignment] + self.prefilled = False + + def _ensure_ctx(self): + """Lazily initialize the processor in the worker process""" + if self.ctx is None: + compiler = GrammarCompilerCache.get_compiler(self.config) + if self.config.json_str is not None: + any_whitespace = self.config.any_whitespace + self.ctx = compiler\ + .compile_json_schema(self.config.json_str, + any_whitespace=any_whitespace) + elif self.config.grammar_str is not None: + self.ctx = compiler.compile_grammar(self.config.grammar_str) + elif self.config.json_object: + any_whitespace = self.config.any_whitespace + self.ctx = compiler\ + .compile_json_schema('{"type": "object"}', + any_whitespace=any_whitespace) + else: + raise ValueError( + "Invalid configuration for xgrammar logits processor") + + def __call__(self, input_ids: list[int], + scores: torch.Tensor) -> torch.Tensor: + + # Skip the structured logits processing if reasoning is not finished. 
+ # reasoner is not None only when `--enable-reasoning` is set. + if self.reasoner is not None and \ + not self.reasoner.is_reasoning_end( + input_ids): + return scores + + if self.ctx is None: + self._ensure_ctx() + + if len(self.matchers) == 0: + self.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + self.token_bitmask = xgr.allocate_token_bitmask( + self.batch_size, self.tokenizer_info.vocab_size) + + if not self.prefilled: + # Have not sampled a token yet + self.prefilled = True + else: + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + sampled_token = input_ids[-1] + assert self.matchers[i].accept_token(sampled_token) + + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + # @ubospica: ideally, fill_next_token_bitmask should be + # parallelized with model decoding + # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303 + matcher.fill_next_token_bitmask(self.token_bitmask, i) + + # token_bitmask is a CPU tensor for use with accept_token and + # fill_next_token_bitmask so we move it to the device of scores + device_type = scores.device.type + dtype = scores.dtype + if device_type != "cuda": + # xgrammar on cpu only supports float32 scores + # see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22 + scores = scores.to("cpu").float().unsqueeze(0) + + # Note: In this method, if the tensors have different dimensions + # on CPU device fails, but on GPU it runs without error. Hence the + # unsqueeze above for scores, to match the token bitmask shape + xgr.apply_token_bitmask_inplace( + scores, self.token_bitmask.to(scores.device, non_blocking=True)) + if device_type != "cuda": + scores = scores.to(dtype).to(device_type).squeeze() + + return scores + + def clone(self) -> XGrammarLogitsProcessor: + """Create a new instance with shared compiled grammar + but separate state""" + new_processor = XGrammarLogitsProcessor(self.config, self.reasoner) + + # Share the compiled grammar context (immutable after compilation) + new_processor.ctx = self.ctx + + # Create fresh matchers for the new sequence + if self.ctx is not None: + new_processor.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + + # Create a new token bitmask with the same size + if hasattr(self, 'token_bitmask') and self.token_bitmask is not None: + new_processor.token_bitmask = self.token_bitmask + + # Copy simple attributes + new_processor.batch_size = self.batch_size + # Reset prefilled state for new sequence + new_processor.prefilled = False + + return new_processor diff --git a/vllm/model_executor/layers/__init__.py b/vllm/model_executor/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py new file mode 100644 index 0000000..c4c17a3 --- /dev/null +++ b/vllm/model_executor/layers/activation.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Custom activation functions.""" +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import LazyDict + + 
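+# Most activations below are registered as vLLM CustomOps: forward_native is
+# a pure-PyTorch reference path, while forward_cuda / forward_xpu dispatch to
+# fused kernels when the current platform provides them.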
+@CustomOp.register("fatrelu_and_mul") +class FatreluAndMul(CustomOp): + """An activation function for FATReLU. + + The function computes x -> FATReLU(x[:d]) * x[d:] where + d = x.shape[-1] // 2. + This is used in openbmb/MiniCPM-S-1B-sft. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self, threshold: float = 0.): + super().__init__() + self.threshold = threshold + if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops + self.op = ops.fatrelu_and_mul + elif current_platform.is_cpu(): + self._forward_method = self.forward_native + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x1 = x[..., :d] + x2 = x[..., d:] + x1 = F.threshold(x1, self.threshold, 0.0) + return x1 * x2 + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x, self.threshold) + return out + + +@CustomOp.register("silu_and_mul") +class SiluAndMul(CustomOp): + """An activation function for SwiGLU. + + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + from vllm import _custom_ops as ops + self.op = ops.silu_and_mul + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) + result = s * x_reshaped[:, d:] + return result.view(*x.shape[:-1], d) + + +@CustomOp.register("mul_and_silu") +class MulAndSilu(CustomOp): + """An activation function for SwiGLU. + + The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2. 
+ + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops + self.op = ops.mul_and_silu + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul + elif current_platform.is_cpu(): + self._forward_method = self.forward_native + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return x[..., :d] * F.silu(x[..., d:]) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + # TODO implement forward_xpu for MulAndSilu + # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + + +@CustomOp.register("gelu_and_mul") +class GeluAndMul(CustomOp): + """An activation function for GeGLU. + + The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) + return: (batch_size, seq_len, d) or (num_tokens, d) + """ + + def __init__(self, approximate: str = "none"): + super().__init__() + self.approximate = approximate + if approximate not in ("none", "tanh"): + raise ValueError(f"Unknown approximate mode: {approximate}") + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + if approximate == "none": + from vllm import _custom_ops as ops + self.op = ops.gelu_and_mul + elif approximate == "tanh": + from vllm import _custom_ops as ops + self.op = ops.gelu_tanh_and_mul + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + if approximate == "none": + self.op = ipex_ops.gelu_and_mul + else: + self.op = ipex_ops.gelu_tanh_and_mul + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + def extra_repr(self) -> str: + return f'approximate={repr(self.approximate)}' + + +@CustomOp.register("gelu_new") +class NewGELU(CustomOp): + + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + from vllm import _custom_ops as ops + self.op = ops.gelu_new + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_new + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + c = math.sqrt(2.0 / math.pi) + return 0.5 * x * (1.0 + torch.tanh(c * + (x + 0.044715 * torch.pow(x, 3.0)))) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.op(x) + + +@CustomOp.register("gelu_fast") +class FastGELU(CustomOp): + + def 
__init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + from vllm import _custom_ops as ops + self.op = ops.gelu_fast + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_fast + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * + (1.0 + 0.044715 * x * x))) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.op(x) + + +@CustomOp.register("quick_gelu") +class QuickGELU(CustomOp): + # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + from vllm import _custom_ops as ops + self.op = ops.gelu_quick + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_quick + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return x * torch.sigmoid(1.702 * x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + # TODO implement forward_xpu for QuickGELU + # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + + +@CustomOp.register("relu2") +class ReLUSquaredActivation(CustomOp): + """ + Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 + """ + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return torch.square(F.relu(x)) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + +class ScaledActivation(nn.Module): + """An activation function with post-scale parameters. + + This is used for some quantization methods like AWQ. 
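+
+    The learned scales are applied by dividing the wrapped activation's
+    output (see forward below); with tensor parallelism each rank only
+    loads its own shard of the scales.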
+ """ + + def __init__( + self, + act_module: nn.Module, + intermediate_size: int, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.act = act_module + self.input_is_parallel = input_is_parallel + if input_is_parallel: + tp_size = get_tensor_model_parallel_world_size() + intermediate_size_per_partition = divide(intermediate_size, + tp_size) + else: + intermediate_size_per_partition = intermediate_size + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.scales = nn.Parameter( + torch.empty(intermediate_size_per_partition, dtype=params_dtype)) + set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.act(x) / self.scales + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + if self.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + shard_size = param_data.shape[0] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +_ACTIVATION_REGISTRY = LazyDict({ + "gelu": + lambda: nn.GELU(), + "gelu_fast": + lambda: FastGELU(), + "gelu_new": + lambda: NewGELU(), + "gelu_pytorch_tanh": + lambda: nn.GELU(approximate="tanh"), + "relu": + lambda: nn.ReLU(), + "relu2": + lambda: ReLUSquaredActivation(), + "silu": + lambda: nn.SiLU(), + "quick_gelu": + lambda: QuickGELU(), +}) + + +def get_act_fn(act_fn_name: str) -> nn.Module: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_REGISTRY: + raise ValueError( + f"Activation function {act_fn_name!r} is not supported.") + + return _ACTIVATION_REGISTRY[act_fn_name] + + +_ACTIVATION_AND_MUL_REGISTRY = LazyDict({ + "gelu": lambda: GeluAndMul(), + "silu": lambda: SiluAndMul(), +}) + + +def get_act_and_mul_fn(act_fn_name: str) -> nn.Module: + """Get an activation-and-mul (i.e. 
SiluAndMul) function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY: + raise ValueError( + f"Activation function {act_fn_name!r} is not supported.") + + return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name] diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 0000000..434845f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from typing import Any, Dict, Optional + +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.triton_utils import HAS_TRITON + +_config: Optional[Dict[str, Any]] = None + + +@contextmanager +def override_config(config): + global _config + old_config = _config + _config = config + yield + _config = old_config + + +def get_config() -> Optional[Dict[str, Any]]: + return _config + + +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", + "override_config", + "get_config", +] + +# if HAS_TRITON: +# import to register the custom ops +import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa +import vllm.model_executor.layers.fused_moe.fused_moe # noqa +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) + +__all__ += [ + "fused_moe", + "fused_topk", + "fused_experts", + "get_config_file_name", + "grouped_topk", + "cutlass_moe_fp8", +] diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..56c1a4e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..d3677be --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..265768f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, 
+ "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..d3be23d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..589f5d3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000..2c78bfa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json 
b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..4da841e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 
128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..2003567 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..e076615 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..ee89655 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..05aed8b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, 
+ "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json 
b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..f10e394 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + 
"3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000..9262a74 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 
100644 index 0000000..d251f9b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..0ecf814 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..51ad5b2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..ee51191 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..68793c7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..6129107 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + 
"256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..039a10e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..3793fca --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..51d03d8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000..26f9abd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..cd0cdbe --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..64be6e6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..0a6a6a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..ba9041d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 
4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..7a7508a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..dbf9a2d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000..bbb2386 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, 
+ "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..5705545 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3584": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..0611620 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json new file mode 100644 index 0000000..43c249d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..43c249d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4dd00d1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json new file mode 100644 index 0000000..48f9697 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a8c0571 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json new file mode 100644 index 0000000..f1244c6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2e692a1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a2ee05d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..63e1187 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + 
}, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e676960 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + 
"num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e676960 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, 
+ "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fc573cd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c6d7e96 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { 
+ "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6fcf408 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c6eabea --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e676960 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + 
"num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..21f6022 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..d09508b --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + 
"matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..746463a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..bbdb9ad --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + 
"matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..43584b1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, 
+ "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..8cc6c64 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..39a9912 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 
128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..05b5463 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end 
of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..d4c9ddd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..c17a4ec --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..170ae7f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": 
{ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..1d9d352 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..9ad5b31 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..2883dfd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..8abfd84 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..2fc18a5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..be8d4a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, 
+ "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..71fdd88 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..b2799ed --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..c02de2f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json new file mode 100644 index 0000000..3e0bc75 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..9f7ed67 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, 
+ "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..b8d3be2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..21b7255 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..eaf32f6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..b6f1d01 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..4bf7753 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + 
"waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..f245285 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + 
"waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..3918c93 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..3f3ccda --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,138 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..841044a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..59be497 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..0e5fd1e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, 
+ "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..d6ad635 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, 
+ "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..16e0a91 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..d766fc0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} 
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..8323f51 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json new file mode 
100644 index 0000000..1b46cb5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 
8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..6d5b1ae --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..ffc1b23 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000..f4c0f84 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..5c8185c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..97c9f44 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + 
"num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..e4110a5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..0883ef4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..81bb765 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..811c77a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..2758e48 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + 
"256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..fc31215 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, 
+ "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..0bb423b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + 
"num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..5557187 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 
+ }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..26bcbf2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json 
b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..1a0aa33 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..9952be6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..379ca10 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..5a3f415 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 
0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..6cb80f4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, 
+ "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..de9d0ab --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 
256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000..b41f9d4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..edf2a38 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json new file mode 100644 index 0000000..32bbadb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..673bae2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..b2100ce --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new 
file mode 100644 index 0000000..e6f753c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..53f3394 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, 
+ "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json new file mode 100644 index 0000000..d720deb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json @@ -0,0 +1,173 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 7 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 128, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 
2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "192": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 128, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 8 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "6144": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..48bb5f2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..a64d06c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..2c49f35 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..c7db6c0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, 
+ "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..dbc6247 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, 
+ "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..cc614e6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..32c0c9d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..4dd475c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..2ed15f3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..bd2c6fb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..8d7b780 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + 
"num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json 
b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..7a07bbf --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..3a3268c --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + 
"matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..f578c8d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..918f683 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..e341a67 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..eb81726 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..0c7062a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..cd4fb8f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 
128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..cf66868 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + 
"num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..c27ca0a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..da477b1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { 
+ "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..34b916e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + 
"512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..96cbc11 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/README b/vllm/model_executor/layers/fused_moe/configs/README new file mode 100644 index 0000000..787bd06 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/README @@ -0,0 +1,13 @@ +This directory contains tuned configurations for different settings of the fused_moe kernel. +For different settings of +- E (number of experts) +- N (intermediate size) +- device_name (torch.cuda.get_device_name()) +the JSON file contains a mapping from M (batch size) to the chosen configuration. + +The example configurations provided are for the Mixtral model for TP2 on H100 +and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have +N = 7168 and for TP4 we have N = 3584. + +Please feel free to tune the configurations using scripts in `benchmarks/kernels/benchmark_moe.py` +Some of the configurations files are copied from the SGLang repository. Thank you! diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py new file mode 100644 index 0000000..d6a27aa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused MoE kernel.""" +from typing import Optional + +import torch + +from vllm import _custom_ops as ops + + +#TODO make the grouped gemm kernel consistent with scaled gemm kernel +def cutlass_moe_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.half, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - ab_strides1 (torch.Tensor): The input and weights strides of the first + grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - ab_strides2 (torch.Tensor): The input and weights strides of the second + grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. 
+ - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + - out_dtype (torch.Tensor): The output tensor type. + + Returns: + - torch.Tensor: The fp16 output tensor after applying the MoE layer. + """ + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_q.dtype == torch.float8_e4m3fn + assert w2_q.dtype == torch.float8_e4m3fn + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert a1_scale is None or a1_scale.dim( + ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ + 0], "Input scale shape mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1_q.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2_q.shape[2], "W2 scale shape mismatch" + assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert ab_strides1.shape[0] == w1_q.shape[ + 0], "AB Strides 1 expert number mismatch" + assert c_strides1.shape[0] == w1_q.shape[ + 0], "C Strides 1 expert number mismatch" + assert ab_strides2.shape[0] == w2_q.shape[ + 0], "AB Strides 2 expert number mismatch" + assert c_strides2.shape[0] == w2_q.shape[ + 0], "C Strides 2 expert number mismatch" + assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + topk = topk_ids.size(1) + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + if apply_router_weight_on_input: + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + # TODO: this only works for topK=1, will need to update for topK>1 + a = a * topk_weights.to(out_dtype) + + a_q, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) + device = a_q.device + + expert_offsets = torch.empty((num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, num_experts, n, + k) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + + c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) + c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) + + ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, + expert_offsets[:-1], problem_sizes1, ab_strides1, + ab_strides1, c_strides1) + + intermediate = torch.empty((m * topk, n), device=device, 
dtype=out_dtype) + torch.ops._C.silu_and_mul(intermediate, c1) + + intemediate_q, a2_scale = ops.scaled_fp8_quant( + intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) + + ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, ab_strides2, + ab_strides2, c_strides2) + # Gather tokens + c2 = c2[c_map].view(m, topk, k) + if not apply_router_weight_on_input: + c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) + return c2.sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py new file mode 100644 index 0000000..353c8cc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +import importlib.util +from typing import Optional, Tuple + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, + _fp8_quantize, + _resize_cache) +from vllm.utils import round_up + +logger = init_logger(__name__) + +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + + +def _valid_deep_gemm(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + expert_map: Optional[torch.Tensor] = None) -> bool: + """ + Check if the given problem size is supported by the DeepGemm grouped + gemm kernel. All of M, N, K and the quantization block_shape must be + aligned by `dg.get_m_alignment_for_contiguous_layout()`. + """ + if not has_deep_gemm: + return False + + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + + # Expert maps not supported yet. + if expert_map is not None: + return False + + align = dg.get_m_alignment_for_contiguous_layout() + M = hidden_states.shape[0] + _, K, N = w2.shape + + # For now, disable DeepGemm for small N until better permute/unpermute + # ops are available. + if N <= 512: + return False + + if align > M or N % align != 0 or K % align != 0: + return False + + return (hidden_states.is_contiguous() and w1.is_contiguous() + and w2.is_contiguous()) + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. + """ + top_k_num = curr_topk_ids.shape[1] + + tokens_in_chunk, _ = curr_hidden_states.shape + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. 
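+    # sorted_token_ids indexes the flattened (token, top_k) space produced by
+    # moe_align_block_size, so dividing by top_k_num maps each block-aligned
+    # slot back to its source row in curr_hidden_states; rows routed to more
+    # than one expert are duplicated by this gather.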
+ curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M, topk = topk_weight.shape + K = curr_hidden.shape[1] + curr_hidden = curr_hidden[inv_perm, ...] + curr_hidden = curr_hidden.view(-1, topk, K) + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) + + +def deep_gemm_moe_fp8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with DeepGemm + grouped gemm. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1 (torch.Tensor): The first set of fp8 quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2 (torch.Tensor): The second set of fp8 quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + + Returns: + - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. + """ + # Lazy import to avoid CUDA initialization problems. 
+ import deep_gemm as dg + + assert expert_map is None, "Expert maps not supported yet" + + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + assert a1_scale is None or a1_scale.dim( + ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ + 0] == hidden_states.shape[0], "Input scale shape mismatch" + assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + + assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + block_m = dg.get_m_alignment_for_contiguous_layout() + block_shape = [block_m, block_m] + + assert w1_scale is not None + assert w2_scale is not None + + # We attempt to transpose and align offline in Fp8MoEMethod, in which + # case these calls will be nops. Otherwise, they'll be performed every + # time the layer is executed. + w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() + w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() + + M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.empty(M_sum * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) + + workspace1 = workspace13[:M_sum * N].view(M_sum, N) + workspace2 = torch.empty((M_sum, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + workspace3 = workspace13[:M_sum * K].view(M_sum, K) + + for chunk in range(num_chunks): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + a1q_scale: Optional[torch.Tensor] = None + + qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, + a1_scale, block_shape) + + (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, + curr_topk_ids, global_num_experts, + expert_map, block_m) + + # Adjust the intermediate cache size and config for the last chunk. 
+ # Note that in most cases we only have one chunk so the cache size + # and config are already set correctly and do not need to be adjusted. + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + curr_M = sorted_token_ids.numel() + workspace1 = _resize_cache(workspace1, (curr_M, N)) + workspace2 = _resize_cache(workspace2, (curr_M, N // 2)) + workspace3 = _resize_cache(workspace3, (curr_M, K)) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1, + expert_ids) + + if activation == "silu": + torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale, + block_shape) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) + + _moe_unpermute_and_reduce( + out_hidden_states[begin_chunk_idx:end_chunk_idx], + workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) + + return out_hidden_states diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py new file mode 100644 index 0000000..ee158d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Optional + +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size, try_get_optimal_moe_config) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.utils import direct_register_custom_op + + +def get_scalar_type(num_bits: int, has_zp: bool): + if has_zp: + assert num_bits == 4 + return scalar_types.uint4 + else: + return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 + + +def single_marlin_moe( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, +) -> torch.Tensor: + """ + This function computes the multiplication of hidden_states with expert + weights used in Marlin MoE, using weights w and top-k gating mechanism. + Its purpose is testing and debugging the fused MoE kernel. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. + - w (torch.Tensor): The set of expert weights. + - scales (torch.Tensor): The quantization scales. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx (Optional[torch.Tensor]): Optional act_order indices. + - sort_indices (Optional[torch.Tensor]): Optional act_order input + permutation. + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. + - num_bits (bool): The number of bits in expert weights quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
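+    # Note: Marlin stores the quantized weights in a packed layout that folds
+    # the K dimension into 16-element tiles, so w.shape[1] corresponds to
+    # K // 16; this is why the hidden-size check below multiplies by 16.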
+ assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] + + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // (num_bits // 2) + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + None, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=hidden_states.device, + requires_grad=False) + + has_zero_point = w_zeros is not None + if w_zeros is None: + w_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + + if g_idx is None: + g_idx = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + if sort_indices is None: + sort_indices = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type = get_scalar_type(num_bits, has_zero_point) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K, + is_k_full, E, topk, block_size_m, True, False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="single_marlin_moe", + op_func=single_marlin_moe, + mutates_args=[], + fake_impl=single_marlin_moe_fake, +) + + +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - w1_scale (torch.Tensor): Scale to be used for w1. + - w2_scale (torch.Tensor): Scale to be used for w2. 
+ - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. + - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. + - sort_indices1 (Optional[torch.Tensor]): The first act_order input + permutation. + - sort_indices2 (Optional[torch.Tensor]): The second act_order input + permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of topk-k elements. + - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. + - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] + + has_no_act_order = (g_idx1 is None and g_idx2 is None + and sort_indices1 is None and sort_indices2 is None) + has_all_act_order = (g_idx1 is not None and g_idx2 is not None + and sort_indices1 is not None + and sort_indices2 is not None) + assert has_no_act_order or has_all_act_order, ( + "g_idx and sorted_indices " + "must be all not None or must be all None") + + has_no_zp = w1_zeros is None and w2_zeros is None + has_all_zp = w1_zeros is not None and w2_zeros is not None + assert has_no_zp or has_all_zp, ("zero points must be both not None or " + "must be both None") + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + topk = topk_ids.shape[1] + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + None, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=current_platform.device_type, + requires_grad=False) + + if has_no_zp: + w1_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + w2_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + + if has_no_act_order: + g_idx1 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + g_idx2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + sort_indices1 = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + sort_indices2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type1 = get_scalar_type(num_bits, has_all_zp) + scalar_type2 = get_scalar_type(num_bits, has_all_zp) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + 
dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + w1_zeros, + g_idx1, + sort_indices1, + workspace, + scalar_type1.id, + M, + 2 * N, + K, + is_k_full, + E, + topk, + block_size_m, + True, + False, + ) + + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + w2_zeros, + g_idx2, + sort_indices2, + workspace, + scalar_type2.id, + M, + K, + N, + is_k_full, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) + + +def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="fused_marlin_moe", + op_func=fused_marlin_moe, + mutates_args=[], + fake_impl=fused_marlin_moe_fake, +) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 0000000..80984b0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,1444 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +import vllm._custom_ops as ops + +from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled, + rocm_aiter_fused_experts, + rocm_aiter_topk_softmax) + +logger = init_logger(__name__) + + +@triton.jit +def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, + token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, + compute_type): + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # 
moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. 
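+        # (Blocks whose expert is not local to this rank carry expert id -1,
+        # which moe_align_block_size emits when an expert_map is provided.)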
+ write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + if use_int4_w4a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ + stride_bn + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ + offs_bn[None, :] * stride_bsn + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \ + stride_bsk + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + (offs_bn[None, :] // 2) * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + offs_bn[None, :] * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. 
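+        # For int4 weights two values are packed per byte, so b_ptrs advances
+        # by only BLOCK_SIZE_K // 2 elements per iteration below.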
+ a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. 
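+    # Illustrative mapping (assuming at least GROUP_SIZE_M blocks of rows):
+    # with GROUP_SIZE_M=2 and num_pid_n=4, pids 0..7 map to (pid_m, pid_n) =
+    # (0,0),(1,0),(0,1),(1,1),(0,2),(1,2),(0,3),(1,3), i.e. two rows of C
+    # blocks are swept together so B tiles are reused while still hot in L2.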
+ pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + if use_int8_w8a16: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8: + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + + offs_bsn * stride_bsn) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=token_mask, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
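+        # (a_ptrs/b_ptrs always move a full BLOCK_SIZE_K here; for the
+        # block-quantized fp8 path the per-group scales are re-indexed inside
+        # the loop via offs_ks rather than advanced with the data pointers.)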
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def invoke_fused_moe_kernel(A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None) -> None: + assert topk_weights is not None or not mul_routed_weight + assert topk_weights is None or topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16,block_shape) + return + + if use_fp8_w8a8: + assert B_scale is not None + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + + M = A.shape[0] + num_tokens = M * top_k + + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. 
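+        # e.g. with A.shape[0] = 4 tokens, top_k = 2 and BLOCK_SIZE_M = 64, at
+        # most 8 expert blocks can be non-empty, so EM is capped below at
+        # 4 * 2 * 64 token slots instead of the full padded length.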
+ EM = min(sorted_token_ids.shape[0], + A.shape[0] * top_k * config['BLOCK_SIZE_M']) + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( + B.shape[1], META['BLOCK_SIZE_N']), ) + + if (use_int8_w8a16 or use_int4_w4a16) and \ + block_shape is not None and block_shape[1] > 0: + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + use_moe_wna16_cuda = should_moe_wna16_use_cuda( + num_valid_tokens=num_tokens, + group_size=block_shape[1], + num_experts=B.shape[0], + bit=4 if use_int4_w4a16 else 8) + config = config.copy() + config.update( + get_moe_wna16_block_config(config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=num_tokens, + size_k=A.shape[1], + size_n=B.shape[1], + num_experts=B.shape[1], + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"])) + + if use_moe_wna16_cuda: + bit = 4 if use_int4_w4a16 else 8 + ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, expert_ids, + num_tokens_post_padded, top_k, + config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], bit) + return + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + else: + config = config.copy() + BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") + if block_shape is not None: + BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], + block_shape[1])) + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) + if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) + if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) + if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) + if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) + if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + BLOCK_SIZE_K=BLOCK_SIZE_K, + **config, + ) + + +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name(E: int, + N: int, + dtype: Optional[str], + block_shape: Optional[List[int]] = None) -> str: + device_name = current_platform.get_device_name().replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + block_shape_selector = ("" if not block_shape or not all(block_shape) else + f",block_shape={block_shape}").replace(" ", "") + return 
f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 + + +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +@functools.lru_cache +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ("Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s"), config_file_path) + return None + + +def get_moe_wna16_block_config(config: Dict[str, + int], use_moe_wna16_cuda: bool, + num_valid_tokens: int, size_k: int, size_n: int, + num_experts: int, group_size: int, + real_top_k: int, block_size_m: int): + if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: + # optimal block config is set + return {} + if not use_moe_wna16_cuda: + # triton moe wna16 kernel + if num_valid_tokens // real_top_k == 1: + # if bs=1, use a smaller BLOCK_SIZE_N + return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64} + else: + return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32} + else: + # cuda moe wna16 kernel + # set default block_size 128, and increase them when num_blocks + # is too large. + block_size_n = 128 + block_size_k = 128 + if block_size_k <= group_size: + block_size_k = group_size + + num_n_blocks = size_k // block_size_k + num_k_blocks = size_n // block_size_k + num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ + num_experts + if num_valid_tokens // real_top_k <= block_size_m: + num_m_blocks = min(num_m_blocks, num_valid_tokens) + num_blocks = num_m_blocks * num_n_blocks * num_k_blocks + + if size_k % 256 == 0 and num_blocks >= 256 and \ + block_size_k < 256: + block_size_k = 256 + num_blocks = num_blocks // (256 // block_size_k) + + if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ + size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ + num_blocks >= 512: + block_size_k = block_size_k * 2 + num_blocks = num_blocks // 2 + + if num_blocks > 1024: + block_size_n = 256 + num_n_blocks = num_n_blocks // 2 + num_blocks = num_blocks // 2 + + if size_n <= 1024 and num_blocks >= 1024: + # The kernel performance got much better with BLOCK_SIZE_N=1024 + # when num_blocks is large, event when N is small. + # Not sure why, maybe it force the CUDA SM process only one block + # at the same time. 
+ block_size_n = 1024 + + return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} + + +def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, + num_experts: int, bit: int): + return bit == 4 and group_size in [32, 64, 128] and \ + num_valid_tokens / num_experts <= 6 + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, + block_shape: Optional[List[int]] = None, +) -> Dict[str, int]: + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: + # moe wna16 kernels + # only set BLOCK_SIZE_M + # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later + bit = 4 if dtype == "int4_w4a16" else 8 + use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, + block_shape[1], E, bit) + if use_moe_wna16_cuda: + config = {"BLOCK_SIZE_M": min(16, M)} + elif M <= 20: + config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1} + elif M <= 40: + config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} + else: + config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + numel = M * topk + if numel <= 64: + config['BLOCK_SIZE_M'] = 32 + elif numel <= 1024: + config['BLOCK_SIZE_M'] = 64 + else: + config['BLOCK_SIZE_M'] = 256 + return config + + +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, + block_shape: Optional[List[int]] = None, +): + from vllm.model_executor.layers.fused_moe import get_config + override_config = get_config() + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + # block_n = block_shape[0] if block_shape else 0 + # block_k = block_shape[1] if block_shape else 0 + # configs = get_moe_configs(E, N, dtype, block_n, block_k) + + configs = None + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin, block_shape) + return config + + +def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> tuple[torch.Tensor, ...]: + ops.topk_softmax( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_indices + + +def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: + if is_rocm_aiter_moe_enabled(): + return rocm_aiter_topk_softmax + return vllm_topk_softmax + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + M, _ = hidden_states.shape + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + 
device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + + gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. + + topk_func = dispatch_topk_func() + topk_weights, topk_ids = topk_func(topk_weights, topk_ids, + token_expert_indicies, + gating_output_float, renormalize) + + del token_expert_indicies # Not used. Will be used in the future. + return topk_weights, topk_ids + + +# This is used by the Deepseek-V2 and Deepseek-V3 model +# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def grouped_topk_( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +from ixformer.inference.functions import moe_grouped_topk as grouped_topk + + +def get_config_dtype_str( + dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False) -> Optional[str]: + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w4a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def inplace_fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + 
use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: + fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, + activation, apply_router_weight_on_input, use_fp8_w8a8, + use_int8_w8a16, use_int4_w4a16, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale, block_shape) + + +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: + pass + + +direct_register_custom_op( + op_name="inplace_fused_experts", + op_func=inplace_fused_experts, + mutates_args=["hidden_states"], + fake_impl=inplace_fused_experts_fake, +) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: + return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, + False, activation, apply_router_weight_on_input, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, + global_num_experts, expert_map, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) + + +def outplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="outplace_fused_experts", + op_func=outplace_fused_experts, + mutates_args=[], + fake_impl=outplace_fused_experts_fake, +) + + +def torch_vllm_inplace_fused_experts(**kwargs) -> 
torch.Tensor: + inplace_fused_experts(**kwargs) + hidden_states = kwargs['hidden_states'] + return hidden_states + + +def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: + return outplace_fused_experts(**kwargs) + + +def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: + if is_rocm_aiter_moe_enabled(): + return rocm_aiter_fused_experts + if inplace: + return torch_vllm_inplace_fused_experts + return torch_vllm_outplace_fused_experts + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: + if (allow_deep_gemm and use_fp8_w8a8 + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + assert apply_router_weight_on_input is False + return deep_gemm_moe_fp8( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + else: + return dispatch_fused_experts_func(inplace)( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) + + +def fused_experts_impl(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None): + # Check constraints. 
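+    # Shape conventions (matching the assertions below): hidden_states is
+    # (num_tokens, K), w1 is (E, N, K) with N = 2 * intermediate_size for the
+    # fused gate/up projection, and w2 is (E, K, N // 2); for int4 weights the
+    # last dimension of w1 packs two values per element, hence the // 2 in the
+    # first assert.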
+ if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.shape[1] + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + dtype=hidden_states.dtype) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + top_k_num, + config_dtype, + block_shape=block_shape, + ) + + config = get_config_func(M) + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M * top_k_num * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) + + # This needs separate memory since it's used concurrently with cache1 + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. 
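+            # e.g. if num_tokens is CHUNK_SIZE + 100, this second chunk holds
+            # only 100 tokens, so the caches sized for CHUNK_SIZE rows are
+            # narrowed and a config tuned for the smaller M is fetched again.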
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.shape[1]] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + a1q_scale: Optional[torch.Tensor] = None + + if use_fp8_w8a8: + qcurr_hidden_states, a1q_scale = _fp8_quantize( + curr_hidden_states, a1_scale, block_shape) + else: + qcurr_hidden_states = curr_hidden_states + a1q_scale = a1_scale + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape) + + if activation == "silu": + ops.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + ops.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + if use_fp8_w8a8: + qintermediate_cache2, a2q_scale = _fp8_quantize( + intermediate_cache2, a2_scale, block_shape) + else: + qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape) + # ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + # out_hidden_states[begin_chunk_idx:end_chunk_idx]) + torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) + return out_hidden_states + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + activation: str = "silu", + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. 
+ - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. 
+ """ + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) + + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py new file mode 100644 index 0000000..89fc63b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -0,0 +1,967 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from enum import Enum +from typing import Callable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch.nn.parameter import UninitializedParameter + +import vllm.envs as envs +from vllm.config import get_current_vllm_config +from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum +from vllm.utils import direct_register_custom_op + +if current_platform.is_cuda_alike(): + from .fused_moe import fused_experts +else: + fused_experts = None # type: ignore +if current_platform.is_tpu(): + # the iterative moe implementation is used until the moe_pallas is fixed + from .moe_torch_iterative import fused_moe as fused_moe_pallas +else: + fused_moe_pallas = None # type: ignore +logger = init_logger(__name__) + + +class FusedMoeWeightScaleSupported(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + BLOCK = "block" + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + raise NotImplementedError + + +@CustomOp.register("unquantized_fused_moe") +class 
UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( + layer.w13_weight.data), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( + layer.w2_weight.data), + requires_grad=False) + # Lazy import to avoid importing triton. + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled, shuffle_weights) + if is_rocm_aiter_moe_enabled(): + # reshaping weights is required for aiter moe kernel. 
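+            # (shuffle_weights reorders w13/w2 into the memory layout the ROCm
+            # AITER fused-MoE kernels expect; the parameters are then rebound
+            # to the shuffled tensors below.)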
+ shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + if current_platform.is_cpu(): + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=envs.VLLM_CPU_MOE_PREPACK, + ) + else: + raise NotImplementedError("CPU MOE only supports x86 arch.") + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + out = self.forward( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) * routed_scaling_factor + if extra_residual is not None: + out += extra_residual + return out + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) + + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + 
activation: str = "silu", + apply_router_weight_on_input: bool = False, + **kwargs, + ): + assert activation == "silu", f"{activation} is not supported." + assert apply_router_weight_on_input is False + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function, + scoring_func, + e_score_correction_bias, + ) + + def forward_hpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None + assert layer is not None + assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for HPU.") + if e_score_correction_bias is not None: + raise NotImplementedError( + "Expert score correction bias is not supported for HPU.") + return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight, + router_logits, top_k) + + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None + assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for TPU.") + if e_score_correction_bias is not None: + raise NotImplementedError( + "Expert score correction bias is not supported for TPU.") + assert activation == "silu", f"{activation} is not supported for TPU." + return fused_moe_pallas(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, + renormalize=renormalize) + + forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda + + +def determine_expert_map( + ep_size: int, ep_rank: int, + global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]: + """ + Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. + + Args: + ep_size (int): The size of the expert parallel group + global_num_experts (int): The total number of experts in the model. + + Returns: + Tuple[int, Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. 
+ - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains -1 for experts not assigned to the current rank. + Returns None if ep_size is 1. + """ + assert ep_size > 0 + if ep_size == 1: + return (global_num_experts, None) + + local_num_experts = global_num_experts // ep_size + + # Create a tensor of size num_experts filled with -1 + expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) + # Create a expert map for the local experts + if ep_rank < (ep_size - 1): + # Each non-last rank gets local_num_experts experts. + expert_map[ep_rank * local_num_experts: + (ep_rank + 1) * local_num_experts] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + else: + # All remaining experts are assigned to the last rank. + local_num_experts = (global_num_experts - ep_rank * local_num_experts) + + expert_map[-local_num_experts:] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + return (local_num_experts, expert_map) + + +class FusedMoE(torch.nn.Module): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. + """ + + def __init__( + self, + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Note: here we guard against accessing the TP and DP groups when + # uninitialized (this happens when testing) + self.tp_size = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank() + self.dp_size = (dp_size + if dp_size is not None else get_dp_group().world_size) + self.dp_rank = (0 + if self.dp_size == 1 else get_dp_group().rank_in_group) + self.global_num_experts = num_experts + + # Use expert parallelism instead of tensor parallelism? 
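+        # Expert parallelism is used only when the parallel config enables it
+        # and there is more than one TP rank to repurpose as EP ranks.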
+ vllm_config = get_current_vllm_config() + use_ep = (vllm_config.parallel_config.enable_expert_parallel + and self.tp_size > 1) + + # For smuggling this layer into the fused moe custom op + # self.use_direct_call = self.dp_size == 1 + self.use_direct_call = True + if not self.use_direct_call: + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + + if use_ep: + # Set TP size to 1 to adjust for EP and adjust EP size and rank + # for DP attention. + self.ep_rank = tp_rank + self.tp_size * self.dp_rank + self.tp_rank = 0 + self.ep_size = self.tp_size * self.dp_size + self.tp_size = 1 + + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + else: + # Adjust TP size for DP attention + self.tp_rank = tp_rank + self.tp_size * self.dp_rank + self.ep_rank = 0 + self.tp_size = self.tp_size * self.dp_size + self.ep_size = 1 + self.local_num_experts = self.global_num_experts + self.expert_map = None + self.top_k = top_k + self.global_num_experts = num_experts + + self.hidden_size = hidden_size + self.num_experts = num_experts + assert intermediate_size % self.tp_size == 0 + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.custom_routing_function = custom_routing_function + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + self.activation = activation + + if self.scoring_func != "softmax" and not self.use_grouped_topk: + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") + if current_platform.is_hpu(): + from vllm_hpu_extension.ops import DynamicFusedMOE + self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + + # Note: get_quant_method will look at the layer's local_num_experts + # for heuristic purposes, so it must be initialized first. 
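+        # Without a quantization config, fall back to the unquantized fused
+        # MoE method; otherwise the config picks the quantized method for
+        # this layer.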
+ if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod()) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.apply_router_weight_on_input = apply_router_weight_on_input + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + self.quant_method.create_weights(layer=self, **moe_quant_params) + + def _load_per_tensor_weight_scale(self, shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale(self, + shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full_w2: bool = False): + """ + Load grouped weight scales for group quantization or model weights + :param shard_dim: dimension to shard + :param expert_data: parameter for a particular expert + :param shard_id: either w1, w2, or w3 + :param loaded_weight: checkpoint weight to load into the param + :param tp_rank: tensor parallel rank + :param load_full_w2: whether or not the w2 loaded should be sharded. + """ + if shard_id == "w2": + # In the case where we have actorder/g_idx, we do not partition the + # w2 scales, as indicated by `load_full` argument, for all tp cases + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + load_full=load_full_w2) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + shard_dim: int, shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. 
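+        # w13 stores gate_proj and up_proj concatenated along the sharded
+        # output dim: w1 fills the first shard_size slots, w3 the second.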
+ else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. + expert_data.copy_(loaded_weight) + + def _load_single_value(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + param_data[expert_id] = loaded_weight + + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return + + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO (mgoin): check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality + loaded_weight = loaded_weight.t().contiguous() if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsWNA16MoEMethod") else loaded_weight + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size_per_partition is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + param.data.copy_(loaded_weight) + return + + # is_transposed: if the dim to shard the weight + # should be flipped. 
Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size_per_partition is + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + shard_dim = int(not shard_dim) + shard_dim_force = getattr(param, "shard_dim", None) + shard_dim = shard_dim_force if shard_dim_force is not None else shard_dim + + full_load = len(loaded_weight.shape) == 3 + if full_load: + shard_dim += 1 + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + final_shape = list(loaded_weight.shape) + if shard_id in ["w1", "w3"]: + final_shape[1] *= 2 + final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + expert_data = param.data if full_load else param.data[expert_id] + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + # this is needed for compressed-tensors only + loaded_weight = loaded_weight.to(param.data.device) + + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. {loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return + + # Case weight scales, zero_points and offset + if ("scale" in weight_name or "zero" in weight_name + or "offset" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", None) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + elif quant_method in [ + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, + ]: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + load_full_w2=getattr(param, "load_full_w2", False)) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + return + + # Case weight_shape + if "weight_shape" in weight_name: + # only required by compressed-tensors + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return + + @staticmethod + def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: 
Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + topk_ids: Optional[torch.Tensor] = None,): + from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + from ixformer.inference.functions import moe_grouped_topk as grouped_topk + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + topk_ids=topk_ids, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(get_dp_group().world_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + get_dp_group().broadcast(buffer[start:end, :], idx) + + return buffer + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor, + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0): + if self.use_direct_call: + return self.forward_impl(hidden_states, router_logits, extra_residual, routed_scaling_factor) + else: + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor, + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0): + assert self.quant_method is not None + + if self.dp_size > 1: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + + # Matrix multiply. 
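+        # quant_method.apply performs expert selection plus the fused expert
+        # GEMMs for this rank's local experts; routed_scaling_factor scales
+        # the routed output and extra_residual, if given, is added on top
+        # (see UnquantizedFusedMoEMethod.apply).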
+ final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + extra_residual=extra_residual, + routed_scaling_factor=routed_scaling_factor, + ) + + if self.dp_size > 1: + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = get_dp_group().all_reduce(final_hidden_states) + final_hidden_states = all_hidden_states[start:end, :] + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def extra_repr(self) -> str: + + s = ( + f"global_num_experts={self.global_num_experts}, " + f"local_num_experts={self.local_num_experts}, " + f"top_k={self.top_k}, " + f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501 + f"tp_size={self.tp_size},\n" + f"ep_size={self.ep_size}, " + f"reduce_results={self.reduce_results}, " + f"renormalize={self.renormalize}, " + f"use_grouped_topk={self.use_grouped_topk}") + + if self.use_grouped_topk: + s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 + + s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 + + return s + + +def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.quant_method is not None + + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, + layer_name: str) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="moe_forward", + op_func=moe_forward, + mutates_args=[], + fake_impl=moe_forward_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py new file mode 100644 index 0000000..07d51ac --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch +import triton +import 
triton.language as tl + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.utils import round_up + + +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, + numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts, ) + tokens_cnts = torch.zeros((num_experts + 1, num_experts), + dtype=torch.int32, + device=topk_ids.device) + cumsum = torch.zeros((num_experts + 1, ), + dtype=torch.int32, + device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1, )]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + +def moe_align_block_size( + 
topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: Optional[torch.Tensor] = None, + pad_sorted_ids: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + - expert_map: A tensor of shape [num_experts] that maps the expert index + from the global space to the local index space of the current + expert parallel shard. If the expert is not in the current expert + parallel shard, the mapping is set to -1. + - pad_sorted_ids: A flag indicating whether the sorted_token_ids length + should be padded to a multiple of block_size, + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be zeroed out to prevent index out of bounds error while + # mapping global expert ids to local expert ids in expert parallelism. 
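+    # Blocks beyond num_tokens_post_pad may never be written by the kernel,
+    # so zero-initialization keeps the expert_map[expert_ids] gather below
+    # in bounds.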
+ expert_ids = torch.zeros((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + if num_experts >= 224: + if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + # Currently requires num_experts=256 + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py new file mode 100644 index 0000000..0365afa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn.functional as F +from torch_xla.experimental.custom_kernel import _histogram + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + device = hidden_states.device + dtype = hidden_states.dtype + assert (num_tokens * topk) % 16 == 0, ( + "The Pallas GMM kernel requires num_tokens * topk to be a multiple of " + f"16 but got {num_tokens * topk}") + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, topk_indices = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + topk_indices = topk_indices.flatten() + topk_argsort_indices = topk_indices.argsort() + topk_argsort_revert_indices = topk_argsort_indices.argsort() + token_indices = torch.arange(num_tokens, + device=device).repeat_interleave(topk) + token_indices = token_indices[topk_argsort_indices] + group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) + + # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout + # from HF Transformers. 
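+    # The GMM kernel computes x[tokens, K] @ w[E, K, N] per expert group, so
+    # the weights are transposed from the HF [E, N, K] layout to [E, K, N].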
+ w1 = w1.transpose(1, 2) + w2 = w2.transpose(1, 2) + + x = hidden_states[token_indices] + x = torch.ops.xla.gmm(x, w1, group_sizes) + x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:] + x = torch.ops.xla.gmm(x, w2, group_sizes) + x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) + + x = x * topk_weights.unsqueeze_(dim=-1) + x = x.sum(dim=-2) + x = x.reshape(orig_shape) + return x diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py new file mode 100644 index 0000000..da27633 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn.functional as F + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + expert_map: [num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + dtype = hidden_states.dtype + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, global_num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + if expert_map is not None: + selected_experts = expert_map[selected_experts] + + final_hidden_states = None + for expert_idx in range(num_experts): + expert_w1 = w1[expert_idx] + expert_w2 = w2[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) + x = F.linear(hidden_states, expert_w1) + gate = F.silu(x[:, :intermediate_size]) + x = x[:, intermediate_size:] * gate + x = F.linear(x, expert_w2) + current_hidden_states = x * expert_weights + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states = final_hidden_states + current_hidden_states + + return final_hidden_states.view(orig_shape) # type: ignore diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py new file mode 100644 index 0000000..ac158a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform + + +def is_rocm_aiter_moe_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_MOE \ + and envs.VLLM_ROCM_USE_AITER \ + + +def is_rocm_aiter_block_scaled_moe_enabled() -> bool: + return is_rocm_aiter_moe_enabled() and \ + envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE + + +def rocm_aiter_fused_experts( + *, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, 
+ w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + expert_mask: Optional[torch.Tensor] = None, + **kwagrs # Ignore additional keyword arguments +) -> torch.Tensor: + + import aiter as rocm_aiter + import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe + + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) + + if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8: + assert w1_scale is not None + assert w2_scale is not None + + local_E = E = w1.shape[0] + if expert_mask is not None: + E = expert_mask.numel() + + topk = topk_ids.shape[1] + model_dim = w1.shape[-1] + dtype = hidden_states.dtype + # The default block sizes are 128 in AITER. + if block_shape is None: + block_shape = [128, 128] + + scale_blk_k = block_shape[1] + + ( + sorted_token_ids, + sorted_weight_buf, + sorted_expert_ids, + num_valid_ids, + out_asm, + ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids, + topk_weights, + E, + model_dim, + dtype, + expert_mask=expert_mask) + + a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k) + rocm_aiter.fmoe_fp8_blockscale_g1u1( + out_asm, + a1, + w1, + w2, + sorted_token_ids, + sorted_weight_buf, + sorted_expert_ids, + num_valid_ids, + topk, + w1_scale.view(local_E, -1), + w2_scale.view(local_E, -1), + a1_scale.t().contiguous(), + block_shape[0], + block_shape[1], + None, + ) + return out_asm + + elif use_fp8_w8a8: + return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weights, + topk_ids=topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False) + + return rocm_aiter.ck_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) + + +def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> tuple[torch.Tensor, ...]: + import aiter as rocm_aiter + rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices, + gating_output, renormalize) + + return topk_weights, topk_indices + + +def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: + """ + Applies shuffle_weight function from AITER to each + input tensor and returns them. + + Args: + *tensors: Variable number of torch.Tensor objects. + + Returns: + A tuple of shuffled tensors. + """ + from aiter.ops.shuffle import shuffle_weight + + return tuple(shuffle_weight(tensor) for tensor in tensors) + + +def expand_weights(*tensors: torch.Tensor, + expansion_dims: list[int]) -> tuple[torch.Tensor, ...]: + """ + Expands the dimensions of input tensors. + + Args: + *tensors: A variable number of torch.Tensor objects. + expansion_dims: A list of expansion dimensions + corresponding to each tensor. + + Returns: + A tuple of tensors with expanded dimensions. + """ + + assert len(tensors) == len(expansion_dims), \ + "Number of tensors must match the number of expansion dimensions." 
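+    # Each 1-D tensor is viewed as [N, 1, 1] and broadcast (without copying)
+    # to [N, dim, 1] via expand.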
+ + return tuple( + tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1)) + for tensor, dim in zip(tensors, expansion_dims)) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py new file mode 100644 index 0000000..db31422 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +from math import prod +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import cdiv + + +def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: + """ + Shrink the given tensor and apply the given view to it. This is + used to resize the intermediate fused_moe caches. + """ + assert prod(v) <= x.numel() + return x.flatten()[:prod(v)].view(*v) + + +def _fp8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + block_shape: Optional[List[int]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform fp8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + if block_shape is None: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + return A, A_scale + + +def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + A permutation routine that works on fp8 types. + """ + if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + else: + return m[idx, ...] 
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py new file mode 100644 index 0000000..e49e3d5 --- /dev/null +++ b/vllm/model_executor/layers/layernorm.py @@ -0,0 +1,279 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Custom normalization layers.""" +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +import vllm.envs as envs +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform + + +def is_rocm_aiter_rmsnorm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_RMSNORM \ + and envs.VLLM_ROCM_USE_AITER + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + from vllm import _custom_ops as ops + out = torch.empty_like(x) + ops.rms_norm( + out, + x, + weight, + variance_epsilon, + ) + return out + + +def fused_add_rms_norm( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float, + residual_alpha: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + ops.fused_add_rms_norm( + x, + residual, + weight, + variance_epsilon, + residual_alpha, + ) + return x, residual + + +def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + + import aiter as rocm_aiter + return rocm_aiter.rms_norm(x, weight, variance_epsilon) + + +def rocm_aiter_fused_add_rms_norm( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + + import aiter as rocm_aiter + + # Assuming the correct signature for rmsnorm2d_fwd_with_add + rocm_aiter.rmsnorm2d_fwd_with_add( + x, # output + x, # input + residual, # residual input + residual, # residual output + weight, + variance_epsilon, + ) + return x, residual + + +def dispatch_cuda_rmsnorm_func(add_residual: bool): + if add_residual: + if is_rocm_aiter_rmsnorm_enabled(): + return rocm_aiter_fused_add_rms_norm + return fused_add_rms_norm + + if is_rocm_aiter_rmsnorm_enabled(): + return rocm_aiter_rms_norm + return rms_norm + + +@CustomOp.register("rms_norm") +class RMSNorm(CustomOp): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. 
+ Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.variance_size_override = (None if var_hidden_size == hidden_size + else var_hidden_size) + self.has_weight = has_weight + if dtype is not None: + self.weight = torch.ones(hidden_size, dtype=dtype) + else: + self.weight = torch.ones(hidden_size) + if self.has_weight: + self.weight = nn.Parameter(self.weight) + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: + raise ValueError("Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}") + + if self.variance_size_override is None: + x_var = x + else: + if hidden_size < self.variance_size_override: + raise ValueError( + "Expected hidden_size to be at least " + f"{self.variance_size_override}, but found: {hidden_size}") + + x_var = x[:, :, :self.variance_size_override] + + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) + if self.has_weight: + x = x * self.weight + if residual is None: + return x + else: + return x, residual + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + residual_alpha: float = 1.0, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + add_residual = residual is not None + norm_func = dispatch_cuda_rmsnorm_func(add_residual) + + if add_residual: + return norm_func(x, residual, self.weight.data, + self.variance_epsilon, residual_alpha) + else: + return norm_func(x, self.weight.data, self.variance_epsilon) + + def forward_hpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + from vllm_hpu_extension.ops import HPUFusedRMSNorm + if HPUFusedRMSNorm is None: + return self.forward_native(x, residual) + if residual is not None: + orig_shape = x.shape + residual += x.view(residual.shape) + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + x = HPUFusedRMSNorm.apply(residual, self.weight, + self.variance_epsilon) + return x.view(orig_shape), residual + + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x + + def forward_xpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + from vllm._ipex_ops import ipex_ops as ops + + if residual is not None: + ops.fused_add_rms_norm( + x, + residual, + self.weight.data, + self.variance_epsilon, + ) + return x, residual + return ops.rms_norm( + x, + self.weight.data, + self.variance_epsilon, + ) + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s + + +@CustomOp.register("gemma_rms_norm") +class 
GemmaRMSNorm(CustomOp): + """RMS normalization for Gemma. + + Two differences from the above RMSNorm: + 1. x * (1 + w) instead of x * w. + 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + @staticmethod + def forward_static( + weight: torch.Tensor, + variance_epsilon: float, + x: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + x = x * (1.0 + weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + return self.forward_static(self.weight.data, self.variance_epsilon, x, + residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if torch.compiler.is_compiling(): + return self.forward_native(x, residual) + + if not getattr(self, "_is_compiled", False): + self.forward_static = torch.compile( # type: ignore + self.forward_static) + self._is_compiled = True + return self.forward_native(x, residual) diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py new file mode 100644 index 0000000..de36077 --- /dev/null +++ b/vllm/model_executor/layers/lightning_attn.py @@ -0,0 +1,651 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +import triton +import triton.language as tl +from einops import rearrange + + +@triton.jit +def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, + d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, + NUM_BLOCK, CBLOCK: tl.constexpr): + # This kernel computes the diagonal blocks of the attention matrix + # Each diagonal block represents attention + # where queries attend to keys in the same block + off = tl.program_id(0) + off_bh = off // NUM_BLOCK # batch-head index + off_block = off % NUM_BLOCK # block index within the sequence + off_cblock = tl.program_id(1) # sub-block index within a block + + off_h = off_bh % h # head index + + # Calculate base offsets for the current batch and head + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + + # Calculate offsets for the current block + block_offset = off_block * BLOCK + qk_block_offset = block_offset * d + v_block_offset = block_offset * e + o_block_offset = block_offset * e + + # Calculate offsets for the current sub-block + cblock_offset = off_cblock * CBLOCK + q_cblock_offset = cblock_offset * d + o_cblock_offset = cblock_offset * e + + # Calculate pointers to the query, key, value, and output tensors + Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :]) + K_trans_block_ptr = (K + qk_offset + qk_block_offset + + 
tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, d)[:, None]) + V_block_ptr = (V + v_offset + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :]) + O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :]) + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + i = off_cblock + q_index = tl.arange(0, CBLOCK) + i * CBLOCK + + # Load query values + q = tl.load(Q_block_ptr, + mask=block_offset + q_index[:, None] < n, + other=0.0).to(tl.float32) + + # Initialize output accumulator + qkv = tl.zeros([CBLOCK, e], dtype=tl.float32) + + # Process all sub-blocks up to and + # including the current one (causal attention) + for j in range(i + 1): + kv_index = tl.arange(0, CBLOCK) + j * CBLOCK + diff = q_index[:, None] - kv_index[None, :] + s_index = s * diff + # Apply causal mask: only attend to positions before the current one + s_index = tl.where(diff >= 0, -s_index, float("-inf")) + decay = tl.exp(s_index) + + # Load key and value + k_trans = tl.load( + K_trans_block_ptr, + mask=block_offset + kv_index[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr, + mask=block_offset + kv_index[:, None] < n, + other=0.0, + ).to(tl.float32) + + # Compute attention scores and apply decay + qk = tl.dot(q, k_trans) * decay + + # Compute weighted values and accumulate + qkv += tl.dot(qk, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + + # Store the result + tl.store( + O_block_ptr, + qkv.to(O_block_ptr.dtype.element_ty), + mask=block_offset + q_index[:, None] < n, + ) + + +@triton.jit +def _fwd_kv_parallel( + K, + V, + K_decay, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + D_FBLOCK: tl.constexpr, + E_FBLOCK: tl.constexpr, + NUM_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the key-value outer + # products for each block in parallel + off_bh = tl.program_id(0) # batch-head index + off_block = tl.program_id(1) # block index + + off_h = off_bh % h # head index + + block_offset = off_block * BLOCK + + # Calculate offsets for the current block + k_block_offset = block_offset * d + v_block_offset = block_offset * e + kv_block_offset = off_block * d * e + + # Calculate base offsets for the current batch and head + k_offset = off_bh * n * d + v_offset = off_bh * n * e + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointers to the key, value, and key-value tensors + K_trans_block_ptr = (K + k_offset + k_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, D_FBLOCK)[:, None]) + V_block_ptr = (V + v_offset + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + KV_block_ptr = (KV + kv_offset + kv_block_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay factors for the current head and block + k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]) + + kv_index = tl.arange(0, CBLOCK) + + # Initialize the key-value outer product accumulator + kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32) + + # Handle the last block which might be smaller than BLOCK + if off_block == NUM_BLOCK - 1: + split_n = n - (NUM_BLOCK - 1) * BLOCK + else: + split_n = BLOCK + left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n + num_blocks = 
min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK) + k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK + + # Process all sub-blocks in the current block + for j in range(num_blocks): + left_bound = (1 - j) * left_shift + # Load key and value, handling boundary conditions + k_trans = tl.load(K_trans_block_ptr - left_shift * d, + mask=kv_index[None, :] >= left_bound, + other=0.0) + v = tl.load(V_block_ptr - left_shift * e, + mask=kv_index[:, None] >= left_bound, + other=0.0) + + # Load decay factor and compute weighted key-value outer product + k_decay = tl.load(k_decay_ptr) + kv += tl.dot(k_trans * k_decay, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + k_decay_ptr += CBLOCK + + # Store the result + tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty)) + + +@triton.jit +def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, + d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, + NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr): + # This kernel reduces the key-value outer products + # across blocks and updates the KV history + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointer to the key-value tensor + KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay rate for the current head + s_ptrs = S + off_h + s = tl.load(s_ptrs) + + # Calculate pointer to the key-value history tensor + kv_history_offset = off_bh * d * e + KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the previous key-value history + kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32) + + # Process all blocks in reverse order to compute the prefix sum + for i in range(NUM_BLOCK): + block_size = min(n - i * BLOCK, BLOCK) + # Compute decay factor for the current block + block_decay = tl.exp(-s.to(tl.float32) * block_size) + + # Load the current key-value outer product + kv_cur = tl.load(KV_block_ptr).to(tl.float32) + # Store the previous key-value history to the current block + tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty)) + + # Update the key-value history with the current block + kv_pre = block_decay * kv_pre + kv_cur + KV_block_ptr += d * e + + # Store the updated key-value history + tl.store(KV_HISTORY_block_ptr, kv_pre) + + +@triton.jit +def _fwd_none_diag_kernel( + Q, + Out, + S, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + E_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the non-diagonal blocks of the attention matrix + # Each non-diagonal block represents attention + # where queries attend to keys in different blocks + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + off_nc = tl.program_id(1) + off_n = off_nc // NUM_CBLOCK # block index + off_c = off_nc % NUM_CBLOCK # sub-block index + off_e = tl.program_id(2) # output feature block index + + n_offset = off_n * BLOCK + c_offset = off_c * CBLOCK + e_offset = off_e * E_FBLOCK + block_offset = n_offset + c_offset + + # Calculate offsets for the current batch, head, and block + q_offset = off_bh * n * d + (n_offset + c_offset) * d + o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset + kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * 
e + e_offset + + # Calculate pointers to the query, output, and key-value tensors + Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :]) + O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + c_array = tl.arange(0, CBLOCK) + + # Load the key-value outer product for the current block + kv = tl.load(KV_block_ptr).to(tl.float32) + q_index = block_offset + tl.arange(0, CBLOCK) + + # Load query values + q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, + other=0.).to(tl.float32) + + # Compute decay factors for the current sub-block + q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None])) + + # Compute non-diagonal attention output + qkv_none_diag = tl.dot(q, kv) * q_decay + + # Load diagonal attention output (computed by _fwd_diag_kernel) + qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, + other=0.).to(tl.float32) + + # Combine diagonal and non-diagonal attention outputs + qkv = qkv_diag + qkv_none_diag + + # Store the result + tl.store(O_block_ptr, + qkv.to(O_block_ptr.dtype.element_ty), + mask=q_index[:, None] < n) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, s, kv_history): + # Forward pass of the lightning attention algorithm + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = s.contiguous() + + # Check CUDA compute capability + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported", + "for compute capability >= 80") + + # Get input dimensions + b, h, n, d = q.shape + e = v.shape[-1] + + # Initialize output tensor + o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + + # Set block sizes + BLOCK = 256 + NUM_BLOCK = triton.cdiv(n, BLOCK) + + CBLOCK = 32 + NUM_CBLOCK = BLOCK // CBLOCK + assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" + + # Compute decay factors for keys + array = torch.arange(0, BLOCK, device=q.device) + 1 + k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1))) + + # Step 1: Compute diagonal blocks of attention + grid = (b * h * NUM_BLOCK, NUM_CBLOCK) + _fwd_diag_kernel[grid](q, + k, + v, + o, + s, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + CBLOCK=CBLOCK) + + # Set feature block sizes + NUM_FBLOCK = 1 + D_FBLOCK = d // NUM_FBLOCK + assert d % NUM_FBLOCK == 0 + E_FBLOCK = e // NUM_FBLOCK + assert e % NUM_FBLOCK == 0 + + CBLOCK = 64 + NUM_CBLOCK = BLOCK // CBLOCK + assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" + + # Step 2: Compute key-value outer products for each block in parallel + kv = torch.empty((b, h, NUM_BLOCK, d, e), + dtype=torch.float32, + device=q.device) + grid = (b * h, NUM_BLOCK) + _fwd_kv_parallel[grid]( + k, + v, + k_decay, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK, + NUM_FBLOCK=NUM_FBLOCK, + CBLOCK=CBLOCK, + NUM_CBLOCK=NUM_CBLOCK, + ) + + # Step 3: Reduce key-value outer products + # across blocks and update KV history + grid = (b * h, NUM_FBLOCK) + _fwd_kv_reduce[grid](s, + kv, + kv_history, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK) + + # Step 4: Compute non-diagonal blocks of attention + grid 
= (b * h, NUM_BLOCK * NUM_CBLOCK) + _fwd_none_diag_kernel[grid]( + q, + o, + s, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + E_FBLOCK=E_FBLOCK, + CBLOCK=CBLOCK, + NUM_CBLOCK=NUM_CBLOCK, + ) + + # Save tensors for backward pass + ctx.save_for_backward(q, k, v, s, kv) + ctx.BLOCK = BLOCK + + return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2) + + +# Apply the lightning attention function +lightning_attention_ = _attention.apply + + +def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): + """ + Apply lightning attention algorithm + to compute attention efficiently. + + Args: + q: Query tensor of shape [batch, heads, seq_len, dim] + k: Key tensor of shape [batch, heads, seq_len, dim] + v: Value tensor of shape [batch, heads, seq_len, dim_v] + ed: Decay rate tensor of shape [heads] + block_size: Size of blocks for block-sparse attention + kv_history: Optional key-value history from previous computations + + Returns: + output: Attention output + kv: Updated key-value history + """ + d = q.shape[-1] + e = v.shape[-1] + + if ed.dim() == 1: + ed = ed.view(1, -1, 1, 1) + + # Split the computation into chunks for better parallelism + m = 128 if d >= 128 else 64 + assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})" + arr = [m * i for i in range(d // m + 1)] + if arr[-1] != d: + arr.append(d) + n = len(arr) + output = 0 + + # Initialize or clone key-value history + if kv_history is None: + kv_history = torch.zeros((q.shape[0], q.shape[1], d, e), + dtype=torch.float32, + device=q.device) + else: + kv_history = kv_history.clone().contiguous() + + # Process each chunk and accumulate results + for i in range(n - 1): + s = arr[i] + e = arr[i + 1] + q1 = q[..., s:e] + k1 = k[..., s:e] + o, kv = lightning_attention_(q1, k1, v, ed, kv_history) + output = output + o + return output, kv + + +@triton.jit +def _linear_attn_decode_kernel( + q_ptr, + k_ptr, + v_ptr, + kv_cache_ptr, + slope_rate, + slot_idx, + output_ptr, + D: tl.constexpr, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d0_stride, + cache_d1_stride, + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for linear attention decoding with KV cache. + + This kernel computes attention for a single token using the KV cache. 
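+
+    In effect each (batch, head) program applies the decayed linear-attention
+    recurrence on its BLOCK_SIZE-wide slice of the value dimension:
+        kv_cache = exp(-slope) * kv_cache + outer(k, v)
+        out      = q @ kv_cache
+    and writes the updated kv_cache back in place.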
+ """ + pid_b = tl.program_id(0) # batch index + pid_h = tl.program_id(1) # head index + pid_d = tl.program_id(2) # dimension block index + + # Load slot index for the current batch + slot_id = tl.load(slot_idx + pid_b) + + # Skip if slot_id is -1 (padding) + if slot_id == -1: + return + + batch_id = pid_b + head_id = pid_h + + # Load decay rate for the current head + ratio = tl.load(slope_rate + pid_h) + + # Calculate offsets for dimensions + qk_d_offsets = tl.arange(0, D) + v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE + cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[ + None, :] * cache_d1_stride + + # Calculate offsets for the current batch and head + q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + + cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride + + # Create masks for loading tensors + qk_mask = qk_d_offsets < D + v_mask = v_d_offsets < D + + # Load query, key, and value tensors + q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0) + k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0) + v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0) + + # Compute key-value outer product + kv_outer = k[:, None] * v[None, :] + kv_mask = qk_mask[:, None] & v_mask[None, :] + + # Apply decay to previous KV cache + ratio = tl.exp(-ratio) + kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets + kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0) + kv_outer = kv_outer + ratio * kv_cache_old + + # Compute attention output + output = q[:, None].to(tl.float32) * kv_outer + output = tl.sum(output, axis=0) + + # Update KV cache and store output + tl.store(kv_ptr, kv_outer, mask=kv_mask) + tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask) + + +def linear_decode_forward_triton( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + slot_idx: torch.Tensor, + BLOCK_SIZE: int = 32, +) -> torch.Tensor: + """ + Perform linear attention decoding using Triton kernels. 
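+    The kernel grid is (B, H, D // BLOCK_SIZE): one program per batch, head
+    and BLOCK_SIZE-wide slice of the value dimension, so D is assumed to be
+    a multiple of BLOCK_SIZE.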
+ + Args: + q: Query tensor of shape [B, H, 1, D] + k: Key tensor of shape [B, H, 1, D] + v: Value tensor of shape [B, H, 1, D] + kv_caches: Key-value cache tensor + slope_rate: Decay rate tensor + slot_idx: Slot indices for batches + BLOCK_SIZE: Size of blocks for processing + + Returns: + output: Attention output tensor + """ + B, H, _, D = q.shape + assert k.shape == (B, H, 1, D) + assert v.shape == (B, H, 1, D) + + # Initialize output tensor + output = torch.empty_like(q) + + # Set grid dimensions for the kernel + grid = (B, H, D // BLOCK_SIZE) + + # Calculate strides for tensors + qkv_b_stride = q.stride(0) + qkv_h_stride = q.stride(1) + + cache_b_stride = kv_caches.stride(0) + cache_h_stride = kv_caches.stride(1) + cache_d0_stride = kv_caches.stride(2) + cache_d1_stride = kv_caches.stride(3) + + # Launch the kernel + _linear_attn_decode_kernel[grid]( + q, + k, + v, + kv_caches, + slope_rate, + slot_idx, + output, + D, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d0_stride, + cache_d1_stride, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Reshape output and return + output = rearrange(output, "b h n d -> b n (h d)") + return output.squeeze(1).contiguous() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py new file mode 100644 index 0000000..762db17 --- /dev/null +++ b/vllm/model_executor/layers/linear.py @@ -0,0 +1,1479 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from abc import abstractmethod +from typing import Any, Literal, Optional, Union + +import torch +import torch.nn as nn +import ixformer.inference.functions as F +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + BlockQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter) +# yapf: enable +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + +WEIGHT_LOADER_V2_SUPPORTED = [ + "CompressedTensorsLinearMethod", + "AWQMarlinLinearMethod", + "AWQLinearMethod", + "GPTQMarlinLinearMethod", + "Fp8LinearMethod", + "MarlinLinearMethod", + "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", + "IPEXGPTQLinearMethod", + "HQQMarlinMethod", + "QuarkLinearMethod", + "ModelOptNvFp4LinearMethod", +] + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def adjust_bitsandbytes_4bit_shard(param: Parameter, + shard_offsets: dict[str, tuple[int, int]], + loaded_shard_id: str) -> tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = shard_offsets["total"] + orig_offset, orig_size = shard_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size 
* quantized_total // total + + return quantized_size, quantized_offset + + +def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +# TODO(Isotr0py): We might need a more flexible structure to handle +# bitsandbytes shard offsets. +def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): + """ + Separate the BitsAndBytes 4-bit shard. + + For example, given bnb weight attributes as below: + { + 'bnb_shard_offsets': array([0, 4, 8, 16]), + 'bnb_quant_state': {0: ..., 1: ..., 2: ...}, + } + + The function will return: + { + 'bnb_shard_offsets': array([0, 4]), + 'bnb_quant_state': {0: ...}, + } + and + { + 'bnb_shard_offsets': array([0, 4, 12]), + 'bnb_quant_state': {0: ..., 1: ...}, + } + """ + shard_offsets = bnb_weight_attrs["bnb_shard_offsets"] + offset_l = shard_offsets[:2] + offset_r = shard_offsets[1:] - shard_offsets[1] + quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]} + quant_state_r = { + i - 1: bnb_weight_attrs["bnb_quant_state"][i] + for i in range(1, + len(shard_offsets) - 1) + } + left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l) + right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r) + return left, right + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError + + @abstractmethod + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply the weights in layer to the input tensor. 
+ Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return F.linear(x, layer.weight, bias) + + +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + return_bias: If true, return bias together with outputs in forward pass. + """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, + prefix=prefix) + self.return_bias = return_bias + self.output_padding_size = 0 + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias) + + # All the linear layer supports quant method. 
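+        # With quant_config=None this resolves to UnquantizedLinearMethod,
+        # whose apply() dispatches to ixformer's F.linear.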
+ assert self.quant_method is not None + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=self.params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + # Special case for GGUF + + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter of size {param.size()}") + param.data.copy_(loaded_weight) + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[list[int]] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the last dimension. + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
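+        # Subclasses (MergedColumnParallelLinear, QKVParallelLinear) set
+        # self.output_sizes before calling this __init__, so each logical
+        # matrix packed into the fused weight gets its own per-rank size.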
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + final_shape = list(loaded_weight.shape) + if output_dim is not None: + tp_size = get_tensor_model_parallel_world_size() + assert final_shape[output_dim] % tp_size == 0 + final_shape[output_dim] = final_shape[output_dim] // tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + param.load_column_parallel_weight(loaded_weight=loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. 
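+            # Every rank then holds the full output at the cost of an extra
+            # collective; callers usually keep gather_output=False and let a
+            # following RowParallelLinear all-reduce instead.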
+ output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + if loaded_shard_id is not None: + param.data[loaded_shard_id].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.shard_weight_type = { + i: loaded_weight.item() + for i, _ in enumerate(self.output_sizes) + } + return + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 2: + self.qweight = param.materialize_nested() + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special 
case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + if use_bitsandbytes_4bit: + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) + for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(shard_id)) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + if use_bitsandbytes_4bit: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. 
+ elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: + shard_size, shard_offset = \ + param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + if loaded_shard_id is None: + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = get_tensor_model_parallel_world_size() + + if isinstance(param, BlockQuantScaleParameter): + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, Fp8MoEMethod) + assert self.quant_method is not None + assert isinstance(self.quant_method, + (Fp8LinearMethod, Fp8MoEMethod)) + weight_block_size = self.quant_method.quant_config.weight_block_size + assert weight_block_size is not None + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + block_n) // tp_size + shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + block_n // tp_size) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. 
The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias) + + def _get_shard_offset_mapping(self, loaded_shard_id: str): + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str): + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. 
+ + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: + shard_size, shard_offset = \ + param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + if loaded_shard_id is None: # special case for certain models + if isinstance(param, PerTensorScaleParameter): + param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + param.load_qkv_weight(loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + idx_map = {"q": 0, "k": 1, "v": 2} + if loaded_shard_id is not None: + param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.shard_weight_type = { + k: loaded_weight.item() + for k in idx_map + } + return + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 3: + self.qweight = param.materialize_nested() + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. 
+ needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", (self.total_num_heads + self.total_num_kv_heads) * + self.head_size, self.total_num_kv_heads * self.head_size), + ] + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.total_num_heads * self.head_size), + "k": (self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + "v": + ((self.total_num_heads + self.total_num_kv_heads) * + self.head_size, + self.total_num_kv_heads * self.head_size), + "total": + ((self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_size, 0) + } + + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, shard_id) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. 
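+                    # (Marlin stores weights in tiles, so adjust_marlin_shard
+                    # rescales both values by marlin_tile_size)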
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, + self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) + } + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, loaded_shard_id) + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, + shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the first dimension. 
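+        # Example: input_size=4096 with tp_size=2 leaves each rank a
+        # [output_size, 2048] weight shard; partial outputs are summed via
+        # all-reduce in forward() when reduce_results is True.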
+ self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) + + param_data = param.data + if input_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
+ if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + + param.load_row_parallel_weight(loaded_weight=loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s + + +class QKVCrossParallelLinear(LinearBase): + """Linear layers for efficient cross-attention's QKV transformation. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + # input_size and output_size are not used, just for alignment + input_size = hidden_size + output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size + super().__init__(input_size=input_size, + output_size=output_size, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + + self.quant_config = quant_config + + # Empty placeholders for loading as a single module. + placeholder_size = 0 + assert self.quant_method is not None + self.quant_method.create_weights(self, + placeholder_size, [placeholder_size], + placeholder_size, + placeholder_size, + self.params_dtype, + weight_loader=self.weight_loader) + + # Use a dictionary to avoid submodules parameters auto-registration: + # drop-in replacement for a `QKVParallelLinear` module. 
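+        # A plain dict (not nn.ModuleDict) keeps the sub-projections'
+        # parameters out of this module's own registry, so checkpoints load
+        # through the placeholder parameters above and weight_loader()
+        # forwards each shard to the matching sub-layer.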
+ self.proj = dict() + self.proj["q_proj_decoder"] = ColumnParallelLinear( + input_size=hidden_size, + output_size=total_num_heads * head_size, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + prefix=f"{prefix}.q_proj_decoder") + + self.proj["kv_proj_encoder"] = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_size, + total_num_heads=0, + total_num_kv_heads=total_num_kv_heads, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + prefix=f"{prefix}.kv_proj_encoder") + + # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. + self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size + + if bias: + self.bias = torch.nn.Parameter() + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.bias = None + + @property + def q_proj_decoder(self) -> ColumnParallelLinear: + layer = self.proj["q_proj_decoder"] + for name, param in self.named_parameters(): + target_param = getattr(layer, name) + self.sync_weight_attrs(param, target_param, mode="q_proj_decoder") + return layer + + @property + def kv_proj_encoder(self) -> QKVParallelLinear: + layer = self.proj["kv_proj_encoder"] + for name, param in self.named_parameters(): + target_param = getattr(layer, name) + self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder") + return layer + + def sync_weight_attrs( + self, + src_param: nn.Parameter, + tgt_param: nn.Parameter, + mode: Literal["q_proj_decoder", "kv_proj_encoder"], + ): + missing_attrs_dict = { + k: getattr(src_param, k) + for k in (set(src_param.__dict__.keys()) - + set(tgt_param.__dict__.keys())) + } + # TODO(Isotr0py): handle bitsandbytes 8bit + use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", + False) + if (missing_attrs_dict and use_bitsandbytes_4bit): + q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard( + missing_attrs_dict) + if mode == "q_proj_decoder": + set_weight_attrs(tgt_param, q_proj_attrs) + elif mode == "kv_proj_encoder": + set_weight_attrs(tgt_param, kv_proj_attrs) + else: + set_weight_attrs(tgt_param, missing_attrs_dict) + + def _is_same_param( + self, + src_param: torch.nn.Parameter, + map_param: torch.nn.Parameter, + ) -> bool: + """Check if two parameters are exactly pointing to same things.""" + # ignore weight_loader because it's always different + key_to_ignore = ["weight_loader", "_weight_loader"] + has_same_type_name = type(src_param) is type(map_param) + src_param_attrs = { + k: v + for k, v in src_param.__dict__.items() if k not in key_to_ignore + } + map_param_attrs = { + k: v + for k, v in map_param.__dict__.items() if k not in key_to_ignore + } + has_same_attrs = src_param_attrs == map_param_attrs + return has_same_type_name and has_same_attrs + + def select_proj_params( + self, + layer: nn.Module, + param: nn.Parameter, + ) -> nn.Parameter: + """ + Given the placeholder param, + return the corresponding param in the proj layers. + """ + target_param_list = [ + v for _, v in layer.named_parameters() + if self._is_same_param(param, v) + ] + assert len(target_param_list) == 1 + target_param = target_param_list[0] + return target_param + + def forward( # type: ignore[override] + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + q, _ = self.q_proj_decoder(decoder_hidden_states) + if encoder_hidden_states is None: + # Encoder KV already cached. 
+ k = None + v = None + else: + # Prefill phase, encoder KV cached here. + kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) + # Split kv in half + k, v = kv_enc.split(self.kv_size, dim=-1) + return q, k, v + + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + layer = (self.q_proj_decoder + if loaded_shard_id == "q" else self.kv_proj_encoder) + target_param = self.select_proj_params(layer, param) + shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + layer.weight_loader(target_param, loaded_weight, *shard_id_args) + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", q_size={self.q_proj_decoder.output_size_per_partition}" + s += f", kv_size={self.kv_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += ", gather_output=False" + return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py new file mode 100644 index 0000000..2238161 --- /dev/null +++ b/vllm/model_executor/layers/logits_processor.py @@ -0,0 +1,199 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A layer that compute logits from hidden_stats.""" +import inspect +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import torch +import torch.nn as nn + +import vllm.envs as envs +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform + +_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None +if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: + _logits_processor_threadpool = ThreadPoolExecutor( + envs.VLLM_LOGITS_PROCESSOR_THREADS) + + +class LogitsProcessor(nn.Module): + """Process logits and apply logits processors from sampling metadata. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super().__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). + self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_all_gather = current_platform.use_all_gather() + + def forward( + self, + lm_head: VocabParallelEmbedding, + hidden_states: torch.Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + if sampling_metadata is not None: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. 
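Further down in this `forward`, the optional `soft_cap` (noted in `__init__` above as the Gemma 2 logit cap) squashes logits with a scaled `tanh` so they stay inside `(-soft_cap, soft_cap)`. A toy sketch of that transform (the numbers are made up):

```python
import torch

def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Smoothly bounds every logit to (-soft_cap, soft_cap) while staying
    # roughly linear for |logit| much smaller than soft_cap.
    return soft_cap * torch.tanh(logits / soft_cap)

logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
capped = soft_cap_logits(logits, soft_cap=30.0)
print(capped)                      # extremes saturate near ±30, small values barely change
assert capped.abs().max() < 30.0
```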
+ if hidden_states.shape[0] > 0: + logits = self._get_logits(hidden_states, lm_head, embedding_bias) + else: + logits = torch.empty([0, lm_head.weight.shape[0]], device=hidden_states.device, dtype=hidden_states.dtype) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = torch.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: + """gather/all-gather the logits tensor across model parallel group.""" + if self.use_all_gather: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) + else: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + return logits + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + + # Gather logits for TP + logits = self._gather_logits(logits) + + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits + + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", forg_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios + # (warmup, profile_run) we might not have selected_token_indices, + # so we skip pruning. 
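In `_get_logits` above, each tensor-parallel rank only produces logits for its shard of the vocabulary; `_gather_logits` then concatenates the shards (gather, or all-gather on platforms without a gather primitive) and the padding beyond `org_vocab_size` is sliced off. A rough single-process sketch of that recombination (shard sizes and padding are invented):

```python
import torch

# Pretend vocabulary of 10 tokens, padded to 12 and split over 2 "ranks".
org_vocab_size, padded_vocab_size, tp_size = 10, 12, 2
shard = padded_vocab_size // tp_size

hidden = torch.randn(3, 8)                   # [num_tokens, hidden_size]
lm_head = torch.randn(padded_vocab_size, 8)  # full (padded) output-embedding matrix

# Each rank holds a contiguous slice of the vocab rows and computes partial logits.
partial = [hidden @ lm_head[r * shard:(r + 1) * shard].t() for r in range(tp_size)]

# The gather step concatenates the shards along the vocab dimension ...
logits = torch.cat(partial, dim=-1)
# ... and the padding added for divisibility is removed.
logits = logits[..., :org_vocab_size]
assert logits.shape == (3, org_vocab_size)
```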
+ if sampling_metadata.selected_token_indices is not None: + return hidden_states.index_select( + 0, sampling_metadata.selected_token_indices) + else: + return hidden_states + + +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + found_logits_processors = False + logits_processed = 0 + logits_row_ids_and_logits_row_futures = [] + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): + logits_row = logits[logits_row_idx] + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + + if _logits_processor_threadpool is not None: + logits_row_ids_and_logits_row_futures.append( + (logits_row_idx, + _logits_processor_threadpool.submit( + _apply_logits_processors_single_seq, logits_row, + logits_processors, past_tokens_ids, + prompt_tokens_ids))) + else: + logits[logits_row_idx] = \ + _apply_logits_processors_single_seq( + logits_row, logits_processors, past_tokens_ids, + prompt_tokens_ids) + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + + for logits_row_idx, future in logits_row_ids_and_logits_row_futures: + logits[logits_row_idx] = future.result() + + if found_logits_processors: + # verifies that no rows in logits were missed unexpectedly + assert logits_processed == logits.shape[0] + return logits + + +def _apply_logits_processors_single_seq(logits_row, logits_processors, + past_tokens_ids, + prompt_tokens_ids) -> torch.Tensor: + for logits_processor in logits_processors: + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, logits_row) + return logits_row diff --git a/vllm/model_executor/layers/mamba/__init__.py b/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py new file mode 100644 index 0000000..156e875 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn +from torch.nn.parameter import Parameter + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer") +class MambaMixer(CustomOp): + 
""" + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + time_step_rank: int, + use_conv_bias: bool, + use_bias: bool, + use_rms_norm: bool, + rms_norm_has_weight: bool = True, + rms_norm_eps: float = 1e-5, + activation="silu", + is_lora_enabled: bool = False): + super().__init__() + self.time_step_rank = time_step_rank + self.ssm_state_size = ssm_state_size + self.use_rms_norm = use_rms_norm + self.activation = activation + self.is_lora_enabled = is_lora_enabled + + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=intermediate_size, + bias=use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=use_bias) + + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + intermediate_size, + time_step_rank + ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(time_step_rank, + intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + intermediate_size // tp_size, + ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + ) + + self.dt_layernorm = RMSNorm( + time_step_rank, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + self.b_layernorm = RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + self.c_layernorm = RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + def forward_native(self, hidden_states: torch.Tensor, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda(self, hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams): + + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + ssm_parameters = self.x_proj( + hidden_states.transpose(-2, -1).contiguous())[0] + else: + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + if self.use_rms_norm: + assert self.dt_layernorm is not None + assert self.b_layernorm is not None + assert self.c_layernorm is not None + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. 
Final linear projection + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1).contiguous())[0] + else: + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1))[0] + return contextualized_states diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py new file mode 100644 index 0000000..d7a45bc --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -0,0 +1,550 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionMetadata) +from vllm.attention.backends.xformers import XFormersMetadata +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + +# Added by the IBM Team, 2024 + + +# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +@CustomOp.register("mixer2_gated_rms_norm") +class Mixer2RMSNormGated(CustomOp): + + def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + assert self.full_hidden_size % self.tp_size== 0,\ + "Tensor parallel world size must divide hidden size." + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. 
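Setting aside the three tensor-parallel cases enumerated above, the computation that follows is a gated, group-wise RMSNorm: gate with SiLU, normalize each contiguous channel group by its own RMS, then scale by the learned weight. A compact single-device reference sketch in plain PyTorch (group count and shapes are illustrative, not the layer's actual configuration):

```python
import torch
import torch.nn.functional as F

def gated_group_rmsnorm(x: torch.Tensor, gate: torch.Tensor,
                        weight: torch.Tensor, n_groups: int,
                        eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    # Gate first, in float32 for numerical stability.
    x = x.float() * F.silu(gate.float())
    # Normalize every contiguous group of channels by that group's RMS.
    *lead, hidden = x.shape
    xg = x.view(*lead, n_groups, hidden // n_groups)
    xg = xg * torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + eps)
    return weight * xg.view(*lead, hidden).to(orig_dtype)

x, gate, w = torch.randn(4, 16), torch.randn(4, 16), torch.ones(16)
out = gated_group_rmsnorm(x, gate, w, n_groups=4)
assert out.shape == x.shape
```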
+ input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = (global_sums / count) + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + else: + redundant_tp: bool = self.n_groups % self.tp_size != 0 + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] + + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + if self.tp_size > 1 or self.n_groups != 1: + return self.forward_native(x, gate) + + from vllm import _custom_ops as ops + + # cast x and gate to float32 before silu + out = torch.empty_like(x) + y = x * nn.functional.silu(gate.to(torch.float32)) + ops.rms_norm( + out, + y.to(x.dtype), + self.weight.data, + self.variance_epsilon, + ) + return out + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups + + +def mamba_v2_sharded_weight_loader( + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + + # - iterate over the shard specs + for full_dim, extra, duplicate_groups in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. + # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard + shard_size = full_dim // tp_size + + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. + # NOTE: currently we only support duplication + # in the case where num_groups == 1 + rank = 0 if duplicate_groups else tp_rank + + # - leftmost boundary index into loaded weight. + loaded_skip = rank * shard_size + loaded_start_idx = loaded_boundary + loaded_skip + + # - take these many dims from the loaded weight. 
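Conceptually, the shard-spec walk above exists because the fused projection weight is a concatenation of several logical segments (gate, hidden states, B, C, dt), and each tensor-parallel rank must take its slice out of *every* segment rather than one contiguous block of the whole tensor. A simplified sketch of that per-segment slicing, ignoring group replication and the `extra`/padding handling (segment sizes are invented):

```python
import torch

def shard_fused_rows(full_weight: torch.Tensor, segment_sizes: list[int],
                     tp_rank: int, tp_size: int) -> torch.Tensor:
    # Take rank `tp_rank`'s slice out of every logical segment along dim 0.
    pieces, offset = [], 0
    for full_dim in segment_sizes:
        shard = full_dim // tp_size
        start = offset + tp_rank * shard
        pieces.append(full_weight[start:start + shard])
        offset += full_dim
    return torch.cat(pieces, dim=0)

# Fused projection with hypothetical segments [gate, x, B, C, dt].
segments = [8, 8, 4, 4, 2]
full = torch.randn(sum(segments), 6)
local = shard_fused_rows(full, segments, tp_rank=1, tp_size=2)
assert local.shape == (sum(segments) // 2, 6)
```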
+ take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[ + boundary:(boundary + take), # type: ignore[misc] + ...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] + loaded_start_idx + take)] # type: ignore[misc] + + # move indexing boundaries + boundary += shard_size + loaded_boundary += (full_dim - extra) + + return loader + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer2") +class MambaMixer2(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation="silu", + chunk_size: int = 256, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - IF, world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. + # - NOTE: currently for the world size DOES NOT divide groups + # case, we only support the case when n_groups == 1 + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + assert num_heads % self.tp_size == 0, \ + "Tensor parallel world size must divide num heads." + + assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ + ( + "If tensor parallel world size does not divide num_heads, " + "then num_groups must equal 1." + ) + + assert self.tp_size == 1 or quant_config is None, \ + "Tensor parallel currently not supported for quantized models." + + self.ssm_state_size = ssm_state_size + self.activation = activation + + self.chunk_size = chunk_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + + self.n_groups = n_groups + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) + + self.conv_dim = (intermediate_size + + 2 * self.n_groups * ssm_state_size) + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size + + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) + + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups to the head shards + group_shard_settings = ( + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + n_groups == 1, # if there was only one group + ) + intermediate_settings = (intermediate_size, 0, False) + head_setings = (self.num_heads, 0, False) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], self.tp_size, tp_rank) + }) + + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) + + # - these are TPed by heads to reduce the size of the + # temporal shape + self.A = nn.Parameter( + torch.empty( + divide(num_heads, self.tp_size), + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.out_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) + + self.norm = Mixer2RMSNormGated(intermediate_size, + n_groups, + eps=rms_norm_eps) + + def forward_native(self, hidden_states: torch.Tensor, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + ): + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + + seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + # detect if there are prefills + has_prefill = attn_metadata.num_prefills > 0 + + # - also need flags to indicate if there are 
initial states + # - currently we really only support the FlashAttention backend + has_initial_states = None + if (isinstance(attn_metadata, + (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata)) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = attn_metadata.context_lens_tensor > 0 + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], + dim=-1, + ) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if has_prefill: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_B_C = causal_conv1d_fn( + hidden_states_B_C.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=has_initial_states, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc).transpose( + 0, 1)[:seq_len] + + # TODO: Why is this needed? + hidden_states_B_C = hidden_states_B_C.contiguous() + else: + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + + # - get hidden_states, B and C after depthwise convolution. + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) + + # 3. 
State Space Model sequence transformation + if has_prefill: + + initial_states = None + if has_initial_states is not None and torch.any( + has_initial_states): + zero_init_indices = mamba_cache_params.state_indices_tensor[ + ~has_initial_states] + mamba_cache_params.ssm_state[zero_init_indices] = 0 + initial_states = mamba_cache_params.ssm_state[ + mamba_cache_params.state_indices_tensor] + + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, + self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, seq_len, self.n_groups // self.tp_size, -1), + C.view(1, seq_len, self.n_groups // self.tp_size, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=sequence_idx, + cu_seqlens=attn_metadata.query_start_loc, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + mamba_cache_params.ssm_state[ + mamba_cache_params.state_indices_tensor] = varlen_state + + # - reshape + hidden_states = scan_output.view(seq_len, -1) + else: + + n_groups = self.n_groups // self.tp_size + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(-1, n_groups, B.shape[1] // n_groups) + C = C.view(-1, n_groups, C.shape[1] // n_groups) + hidden_states_reshaped = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into number of current batches + # - in this case there is no more prefill, so the batches gen + # 1 token at a time + # - thus hidden will be (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using "mamba_cache_params.state_indices_tensor", just as + # above in the prefill case + + hidden_states = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor, + ) + hidden_states = hidden_states.view( + -1, (self.num_heads // self.tp_size) * self.head_dim) + + # # 4. gated MLP + hidden_states = self.norm(hidden_states, gate) + + # # 5. Final linear projection + out, _ = self.out_proj(hidden_states) + return out diff --git a/vllm/model_executor/layers/mamba/ops/__init__.py b/vllm/model_executor/layers/mamba/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py new file mode 100644 index 0000000..21e2716 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao. 
+# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID + + +def causal_conv1d_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, + cache_indices, has_initial_state, activation + in ["silu", "swish"], pad_slot_id) + return x + + +def causal_conv1d_update(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
+ pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, + cache_seqlens, conv_state_indices, pad_slot_id) + if unsqueeze: + x = x.squeeze(-1) + return x diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py new file mode 100644 index 0000000..b31b980 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py + +import torch +import triton +import triton.language as tl +from packaging import version + +from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics( + {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({ + "HAS_STATE_BATCH_INDICES": + lambda args: args["state_batch_indices_ptr"] is not None +}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + state_batch_indices_ptr, + pad_slot_id, + # Matrix dimensions + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate + # is taken from the state_batch_indices_ptr Otherwise, the state coordinate + # is the same as the batch id. 
+ if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr) + state_ptr += (state_batch_idx * stride_state_batch + + pid_h * stride_state_head) + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // + nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // + nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + state = tl.load(state_ptrs, mask=mask, other=0.0) + + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + tl.store(state_ptrs, state, mask=mask) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, 
ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + + _, nheads, dim, dstate = state.shape + batch = x.shape[0] + + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape == (batch, ) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else + ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( + -1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + state_batch_indices, + pad_slot_id, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), + dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_fn(u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: + """ + u: (dim, total_length) for varlen or (batch, dim, seqlen) + applies changes in place. + ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) + applies changes in place. 
+ delta: (dim, total_length) for varlen or (batch, dim, seqlen) + A: (dim, dstate) + B: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + C: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + D: (dim,) + z: (dim, total_length) for varlen or (batch, dim, seqlen) + dt_bias: (dim,) or (dim) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended with 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + A tensor with each cell is a correspondent + input and output ssm_state index + has_initial_state: (batch) bool + A tensor populated with ones and zeros, + indicate if the ssm_state at the corresponding index should be + used as initial state. Not providing argument assumes + there's no initial state + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 + returns + output: (dim, total_length) for varlen or (batch, dim, seqlen) + supports inplace replacement + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3 and query_start_loc is None: + B = B.unsqueeze(1) + if B.dim() == 2 and query_start_loc is not None: + B = B.unsqueeze(0) + if C.dim() == 3 and query_start_loc is None: + C = C.unsqueeze(1) + if C.dim() == 2 and query_start_loc is not None: + C = C.unsqueeze(0) + + ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, + query_start_loc, cache_indices, has_initial_state, + ssm_states, pad_slot_id) + + if z is None: + return delta # output written inplace to delta + else: + return z # output written inplace to z diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py new file mode 100644 index 0000000..388a633 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
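The `selective_scan_fn` wrapper above ultimately realizes the SSM recurrence h_t = exp(Δ_t A) · h_{t-1} + Δ_t B_t u_t with readout y_t = C_t · h_t + D u_t. A minimal sequential reference for a single sequence and a single group, ignoring varlen batching, the optional softplus on Δ, and the z gate (all shapes here are illustrative, not the kernel's calling convention):

```python
import torch

def reference_selective_scan(u, delta, A, B, C, D):
    # u, delta: (dim, seqlen); A: (dim, dstate); B, C: (dstate, seqlen); D: (dim,)
    dim, seqlen = u.shape
    h = torch.zeros(dim, A.shape[1])
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, t:t + 1] * A)        # (dim, dstate) discretized decay
        dB = delta[:, t:t + 1] * B[:, t]             # (dim, dstate) discretized input map
        h = dA * h + dB * u[:, t:t + 1]              # recurrent state update
        ys.append(h @ C[:, t] + D * u[:, t])         # readout plus skip connection
    return torch.stack(ys, dim=-1)                   # (dim, seqlen)

out = reference_selective_scan(torch.randn(4, 6), torch.rand(4, 6),
                               -torch.rand(4, 3), torch.randn(3, 6),
                               torch.randn(3, 6), torch.ones(4))
assert out.shape == (4, 6)
```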
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * 
stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + + +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 0000000..7ef5111 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,619 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
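`_bmm_chunk_fwd` above computes, chunk by chunk, the Gram matrix `out[..., m, n] = sum_k a[..., m, k] * b[..., n, k]` between rows of `a` and `b` that fall in the same chunk. The dense PyTorch equivalent below (ignoring the `ngroups`, `seq_idx` and `causal` options) can serve as a correctness reference when porting the kernel; it is a sketch written for this note, not code from the patch.

```python
import math

import torch
import torch.nn.functional as F


def bmm_chunk_ref(a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Dense reference for _bmm_chunk_fwd with a, b of shape (batch, seqlen, k).

    Returns (batch, nchunks, chunk_size, chunk_size), like the kernel.
    """
    batch, seqlen, k = a.shape
    nchunks = math.ceil(seqlen / chunk_size)
    pad = nchunks * chunk_size - seqlen
    # Zero-pad the sequence so it splits evenly into chunks; the padded rows
    # contribute zeros, mirroring the kernel's masked loads and stores.
    a = F.pad(a, (0, 0, 0, pad)).reshape(batch, nchunks, chunk_size, k)
    b = F.pad(b, (0, 0, 0, pad)).reshape(batch, nchunks, chunk_size, k)
    # out[b, c, m, n] = sum_k a[b, c, m, k] * b[b, c, n, k]
    return torch.einsum("bcmk,bcnk->bcmn", a.float(), b.float())
```

Comparing this against `_bmm_chunk_fwd(a, b, chunk_size, output_dtype=torch.float32)` with `torch.testing.assert_close` is a quick sanity check when bringing the kernel up on a new backend.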
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, 
mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=c_idx >= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care of + if ((c_off == 0) and + (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) + + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1 # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. 
+ + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, + mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) + and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)), + other=0.0).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. 
+ scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. 
+ cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + +def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): + + # convert seq_idx to chunk indices and offsets + # - derive the cu_seqlens + _, cu_seqlens = torch.where(seq_idx.diff()) + cu_seqlens += 1 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device) + chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device) + + cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]] + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size + > 0) + + # adjust inidces and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + initial_states=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, 
dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + chunk_indices, chunk_offsets = None, None + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert initial_states.shape == (seq_idx[0].max() + 1, nheads, + headdim, dstate) + + if initial_states.shape[0] == 1: + # no in this case no point to use initial states + initial_states = None + else: + chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( + seq_idx, chunk_size) + + # Allocates output. + out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks + if chunk_offsets is None else len(chunk_offsets), nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py new file mode 100644 index 0000000..a970ac9 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,750 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
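`_seq_idx_to_chunk_indices_offsets` in `ssd_chunk_scan.py` above is what makes initial states work with continuous batching: any physical chunk that straddles a sequence boundary is split into "pseudo chunks", each recording which physical chunk it reads from and the offset at which its sequence starts. A small worked example with two packed sequences of lengths 5 and 3 and `chunk_size = 4`:

```python
import torch

from vllm.model_executor.layers.mamba.ops.ssd_chunk_scan import (
    _seq_idx_to_chunk_indices_offsets)

# seq_idx labels each packed token with its sequence: tokens 0-4 belong to
# sequence 0 and tokens 5-7 to sequence 1.
seq_idx = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.int32)

chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(seq_idx,
                                                                 chunk_size=4)
# chunk_indices -> [0, 1, 1], chunk_offsets -> [0, 0, 1]:
#   logical chunk 0: physical chunk 0, offset 0 -> tokens 0-3 (sequence 0)
#   logical chunk 1: physical chunk 1, offset 0 -> token 4    (tail of sequence 0)
#   logical chunk 2: physical chunk 1, offset 1 -> tokens 5-7 (sequence 1)
```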
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py + +# ruff: noqa: E501 + +import math + +import torch +import triton +import triton.language as tl + +from .mamba_ssm import softplus + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 
'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < 
dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = 
tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k) & + (offs_k[None, :] >= start_idx_cur - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate) & + (offs_k[:, None] >= start_idx_cur - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + # If HAS_INITSTATES==True need to consider two possiblties + # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs + # - if state_idx >= pid * chunk_size, then we need to insert initstates + if ((start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES)): + + dA_cs_boundary = 0.0 # default + + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + + # - this seems repetitve, buts its to help the compiler + if start_idx < pid_c * chunk_size: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + 
else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load(dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - + 1) * stride_dA_cs_csize).to( + tl.float32) + + past_states = tl.load(past_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads, ) + if dt_bias is not None: + assert dt_bias.shape == (nheads, ) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + 
states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen(B, + x, + dt, + dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. + cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + HAS_INITSTATES=initial_states is not None) + return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py new file mode 100644 index 0000000..97cdb70 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
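`_chunk_cumsum_fwd` in `ssd_chunk_state.py` above is step 1 of the combined scan: it reshapes `dt` into chunks, optionally applies the softplus and the `[dt_min, dt_max]` clamp, and returns both the processed `dt` and the running sum of `A * dt` within each chunk. A plain PyTorch sketch of the same computation, written here as a reference rather than taken from the patch:

```python
import math

import torch
import torch.nn.functional as F


def chunk_cumsum_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                     dt_limit=(0.0, float("inf"))):
    """Reference for _chunk_cumsum_fwd.

    dt: (batch, seqlen, nheads), A: (nheads,).
    Returns (dA_cumsum, dt_out), both (batch, nheads, nchunks, chunk_size).
    """
    batch, seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias.float()
    if dt_softplus:
        # The kernel skips softplus for large inputs, where it is the identity.
        dt = torch.where(dt <= 20.0, F.softplus(dt), dt)
    dt = dt.clamp(dt_limit[0], dt_limit[1])
    # Zero-pad the trailing partial chunk, matching the kernel's masked stores.
    dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    dt = dt.reshape(batch, nchunks, chunk_size, nheads).permute(0, 3, 1, 2)
    dA_cumsum = torch.cumsum(dt * A.float().view(1, -1, 1, 1), dim=-1)
    return dA_cumsum, dt
```

The `(batch, nheads, nchunks, chunk_size)` layout produced here is exactly the shape asserted on `dt` and `dA_cumsum` by `_chunk_state_fwd` and `_chunk_scan_fwd` later in the pipeline.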
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py + +# ruff: noqa: E501 + +import torch +import triton +from einops import rearrange +from packaging import version + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, + chunk_state_varlen) +from .ssd_state_passing import _state_passing_fwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def _mamba_chunk_scan_combined_fwd(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads, ) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride( + 1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride( + 1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, + headdim, dstate) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=seq_idx, + states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx and iii) has_cu_seqlens to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - this will ensure that states will be updated with the rightmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. 
+ states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None else None, + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=C.dtype, + is_cont_batched=cu_seqlens is not None) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, + B, + chunk_size, + seq_idx=seq_idx, + output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. + out, out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + initial_states=initial_states, + ) + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def mamba_chunk_scan_combined(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, + varlen_states) if not return_final_states else (out, + final_states, + varlen_states) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py new file mode 100644 index 0000000..d8f87c1 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -0,0 +1,207 @@ +# 
SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py + +# ruff: noqa: E501 + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + # Matrix dimensions + dim, + nchunks, + seqlen, + chunk_size, + # Strides + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_initstates_head + if not IS_CONT_BATCHED: + initstates_ptr += pid_b * stride_initstates_batch + + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + # - states will be the past state of the sequence that continues on the current check + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_new below. + seq_idx_new = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) + if HAS_INITSTATES: + if IS_CONT_BATCHED and seq_idx != seq_idx_new: + # this means in the current chunk the rightmost flushed seq + # has changed. 
+ # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) + else: + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_chunk_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + is_cont_batched=False, +): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + if is_cont_batched: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert seq_idx is not None, "" + assert initial_states.shape == (seq_idx.max().item() + 1, nheads, + dim) + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. + assert initial_states.shape == (batch, nheads, dim) + + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), + device=states.device, + dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), + device=states.device, + dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, + out, + final_states, + dA_chunk_cumsum, + initial_states, + seq_idx, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size if seq_idx is not None else 0, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_chunk_cumsum.stride(0), + dA_chunk_cumsum.stride(2), + dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2)) if initial_states is not None else + (0, 0, 0)), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_CONT_BATCHED=is_cont_batched, + ) + return out, final_states diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py new file mode 100644 index 0000000..0012636 --- /dev/null +++ b/vllm/model_executor/layers/pooler.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: Apache-2.0 + +from enum import IntEnum +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig +from typing_extensions import assert_never + +from vllm.config import PoolerConfig +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.transformers_utils.config import ( 
+ get_cross_encoder_activation_function) + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + LAST = 0 + ALL = 1 + CLS = 2 + STEP = 3 + MEAN = 4 + + +class SimplePooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use. + normalize: Whether to normalize the pooled data. + """ + + @staticmethod + def from_pooling_type( + pooling_type: PoolingType, + *, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[List[int]] = None, + ) -> "SimplePooler": + if pooling_type == PoolingType.LAST: + assert step_tag_id is None and returned_token_ids is None + return LastPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.ALL: + assert step_tag_id is None and returned_token_ids is None + return AllPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.CLS: + assert step_tag_id is None and returned_token_ids is None + return CLSPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.MEAN: + assert step_tag_id is None and returned_token_ids is None + return MeanPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.STEP: + return StepPool(normalize=normalize, + softmax=softmax, + step_tag_id=step_tag_id, + returned_token_ids=returned_token_ids) + + assert_never(pooling_type) + + def __init__(self, *, normalize: bool, softmax: bool) -> None: + super().__init__() + + self.head = PoolerHead(normalize=normalize, softmax=softmax) + + def get_prompt_lens( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor: + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + raise NotImplementedError + + def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput: + return PoolingSequenceGroupOutput(data) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data) + pooled_outputs = [self.build_output(data) for data in pooled_data] + return PoolerOutput(outputs=pooled_outputs) + + +class CLSPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + first_token_flat_indices = torch.zeros_like(prompt_lens) + first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] + return hidden_states[first_token_flat_indices] + + +class LastPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + return hidden_states[last_token_flat_indices] + + +class AllPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + 
pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + offset = 0 + pooled_data = list[torch.Tensor]() + for prompt_len in prompt_lens: + pooled_data.append(hidden_states[offset:offset + prompt_len]) + offset += prompt_len + + return pooled_data + + +class MeanPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + cumsum = torch.cumsum(hidden_states, dim=0) + start_indices = torch.cat([ + torch.tensor([0], device=hidden_states.device), + torch.cumsum(prompt_lens[:-1], dim=0) + ]) + end_indices = torch.cumsum(prompt_lens, dim=0) + return (cumsum[end_indices - 1] - cumsum[start_indices] + + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) + + +class StepPool(SimplePooler): + + def __init__( + self, + *, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[List[int]] = None, + ): + super().__init__(normalize=normalize, softmax=softmax) + + self.step_tag_id = step_tag_id + self.returned_token_ids = returned_token_ids + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + returned_token_ids = self.returned_token_ids + if returned_token_ids is not None and len(returned_token_ids) > 0: + hidden_states = hidden_states[:, returned_token_ids] + + step_tag_id = self.step_tag_id + + offset = 0 + pooled_data = list[torch.Tensor]() + for prompt_len, seq_data_i in zip(prompt_lens, + pooling_metadata.seq_data.values()): + pooled_data_i = hidden_states[offset:offset + prompt_len] + if step_tag_id is not None: + token_ids = torch.tensor(seq_data_i.prompt_token_ids) + pooled_data_i = pooled_data_i[token_ids == step_tag_id] + + offset += prompt_len + pooled_data.append(pooled_data_i) + + return pooled_data + + +class PoolerHead(nn.Module): + + def __init__(self, *, normalize: bool, softmax: bool) -> None: + super().__init__() + + self.normalize = normalize + self.softmax = softmax + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]): + if self.normalize: + if isinstance(pooled_data, list): + pooled_data = [ + F.normalize(data, p=2, dim=1) for data in pooled_data + ] + else: + pooled_data = F.normalize(pooled_data, p=2, dim=1) + + if self.softmax: + if isinstance(pooled_data, list): + pooled_data = [F.softmax(data, dim=-1) for data in pooled_data] + else: + pooled_data = F.softmax(pooled_data, dim=-1) + + return pooled_data + + +class Pooler(nn.Module): + + @classmethod + def from_config_with_defaults( + cls, + pooler_config: PoolerConfig, + pooling_type: PoolingType, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[List[int]] = None, + ) -> SimplePooler: + return SimplePooler.from_pooling_type( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type, + normalize=pooler_config.normalize + if pooler_config.normalize is not None else normalize, + softmax=pooler_config.softmax + if pooler_config.softmax is not None else softmax, + step_tag_id=pooler_config.step_tag_id + if pooler_config.step_tag_id is not None else step_tag_id, + 
returned_token_ids=pooler_config.returned_token_ids + if pooler_config.returned_token_ids is not None else + returned_token_ids, + ) + + +class CrossEncodingPooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use. + normalize: Whether to normalize the pooled data. + """ + + def __init__( + self, + config: PretrainedConfig, + classifier: nn.Module, + pooler: Optional[nn.Module] = None, + ): + super().__init__() + self.classifier = classifier + self.pooler = pooler + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools sentence pair scores from the hidden_states.""" + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + offset = 0 + pooled_data_lst = [] + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + + if self.pooler is not None: + final_shape_tensor = self.pooler(pooled_data_i) + else: + final_shape_tensor = self.classifier(pooled_data_i) + + pooled_data_lst.append(final_shape_tensor) + offset += prompt_len + + pooled_output = torch.stack(pooled_data_lst) + + if self.pooler is not None: + # apply classifier once on the full batch if possible + pooled_output = self.classifier(pooled_output) + + scores = self.default_activation_function(pooled_output).squeeze(-1) + + pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py new file mode 100644 index 0000000..b1b2916 --- /dev/null +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Type + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +QUANTIZATION_METHODS: List[str] = [ + "aqlm", + "awq", + "deepspeedfp", + "tpu_int8", + "fp8", + "ptpc_fp8", + "fbgemm_fp8", + "modelopt", + "nvfp4", + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin", + "gguf", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "gptq", + "compressed-tensors", + "bitsandbytes", + "qqq", + "hqq", + "experts_int8", + "neuron_quant", + "ipex", + "quark", + "moe_wna16", + "w8a16", +] + +# The customized quantization methods which will be added to this dict. +_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} + + +def register_quantization_config(quantization: str): + """Register a customized vllm quantization config. + + When a quantization method is not supported by vllm, you can register a customized + quantization config to support it. + + Args: + quantization (str): The quantization method name. + + Examples: + >>> from vllm.model_executor.layers.quantization import register_quantization_config + >>> from vllm.model_executor.layers.quantization import get_quantization_config + >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + >>> + >>> @register_quantization_config("my_quant") + ... class MyQuantConfig(QuantizationConfig): + ... 
pass + >>> + >>> get_quantization_config("my_quant") + + """ # noqa: E501 + + def _wrapper(quant_config_cls): + if quantization in QUANTIZATION_METHODS: + raise ValueError( + f"The quantization method `{quantization}` is already exists.") + if not issubclass(quant_config_cls, QuantizationConfig): + raise ValueError("The quantization config must be a subclass of " + "`QuantizationConfig`.") + _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls + QUANTIZATION_METHODS.append(quantization) + return quant_config_cls + + return _wrapper + + +def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + + # lazy import to avoid triggering `torch.compile` too early + from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig + + from .aqlm import AQLMConfig + from .awq import AWQConfig + from .awq_marlin import AWQMarlinConfig + from .bitsandbytes import BitsAndBytesConfig + from .compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) + from .deepspeedfp import DeepSpeedFPConfig + from .experts_int8 import ExpertsInt8Config + from .fbgemm_fp8 import FBGEMMFp8Config + from .fp8 import Fp8Config + from .gguf import GGUFConfig + from .gptq import GPTQConfig + from .gptq_marlin import GPTQMarlinConfig + from .gptq_marlin_24 import GPTQMarlin24Config + from .hqq_marlin import HQQMarlinConfig + from .ipex_quant import IPEXConfig + from .marlin import MarlinConfig + from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config + from .moe_wna16 import MoeWNA16Config + from .neuron_quant import NeuronQuantConfig + from .ptpc_fp8 import PTPCFp8Config + from .qqq import QQQConfig + from .tpu_int8 import Int8TpuConfig + from .w8a16 import W8a16Config + + method_to_config: Dict[str, Type[QuantizationConfig]] = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "tpu_int8": Int8TpuConfig, + "fp8": Fp8Config, + "fbgemm_fp8": FBGEMMFp8Config, + "modelopt": ModelOptFp8Config, + "nvfp4": ModelOptNvFp4Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin": MarlinConfig, + "gguf": GGUFConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, + "awq_marlin": AWQMarlinConfig, + "gptq": GPTQConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, + "ptpc_fp8": PTPCFp8Config, + "qqq": QQQConfig, + "hqq": HQQMarlinConfig, + "experts_int8": ExpertsInt8Config, + "neuron_quant": NeuronQuantConfig, + "ipex": IPEXConfig, + "quark": QuarkConfig, + "moe_wna16": MoeWNA16Config, + "w8a16": W8a16Config, + } + # Update the `method_to_config` with customized quantization methods. 
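+    # Names added via @register_quantization_config were also appended to
+    # QUANTIZATION_METHODS above, so they pass the validity check and resolve
+    # here exactly like the built-in methods.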
+ method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) + + return method_to_config[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", + "QUANTIZATION_METHODS", +] diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py new file mode 100644 index 0000000..10f5241 --- /dev/null +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -0,0 +1,374 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Supports AQLM compression, see https://github.com/Vahe1994/AQLM +# and https://arxiv.org/pdf/2401.06118.pdf + +import math +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +def get_int_dtype(nbits: int) -> torch.dtype: + if nbits <= 8: + return torch.int8 + if nbits <= 16: + return torch.int16 + if nbits <= 32: + return torch.int32 + if nbits <= 64: + return torch.int64 + raise ValueError(f"No dtype available for {nbits}-bit codebooks") + + +@torch.inference_mode() +def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: + return data.to(torch.int64) % (2**nbits) + + +def dequantize_weight(codes: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Decode float weights from quantization codes. Differentiable. + :param codes: tensor of integer quantization codes, shape + [*dims, num_out_groups, num_in_groups, num_codebooks] + :param codebooks: tensor of vectors for each quantization code, + [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, must be + broadcastble with + [*dims, out_groups, num_in_groups, out_group_size, in_group_size] + :return: reconstructed weight tensor of shape + [*dims, num_in_groups*group_size] + """ + num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] + num_codebooks, codebook_size, out_group_size, in_group_size = \ + codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + codebook_offsets = torch.arange( + 0, num_codebooks * codebook_size, codebook_size, + device=codes.device) # shape: [num_codebooks] + reconstructed_weight_flat = F.embedding_bag( + codes.flatten(0, -2) + codebook_offsets, + codebooks.flatten(0, 1).flatten(-2, -1), + mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size + # * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + + [num_out_groups, num_in_groups, out_group_size, in_group_size]) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( + scales) + return reconstructed_weight_groupwise.swapaxes( + -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + + +def dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. 
+ Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + bias: Optional[torch.Tensor], +) -> torch.Tensor: + dequantized_weight = dequantize_weight( + unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), + codebooks, + scales, + ) + return F.linear(input, dequantized_weight, bias) + + +# Generic dequantization, slow but flexible. +def generic_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: List[int], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output_shape = input.shape[:-1] + (scales.shape[0], ) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + num_outputs = len(output_partition_sizes) + + # break the inputs and codebooks apart then combine the outputs. + # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big + # multiply at the end. + num_codebooks = codebooks.shape[0] // num_outputs + assert (scales.shape[0] == codes.shape[0]) + assert (sum(output_partition_sizes) == scales.shape[0]) + output_offset = 0 + codebooks_offset = 0 + for output_size in output_partition_sizes: + shard_output = dequantize_gemm( + input, codes.narrow(0, output_offset, output_size), + codebooks.narrow(0, codebooks_offset, num_codebooks), + scales.narrow(0, output_offset, output_size), None + if bias is None else bias.narrow(0, output_offset, output_size)) + + output_slice = output.narrow(-1, output_offset, output_size) + assert (output_slice.shape == shard_output.shape) + output_slice.copy_(shard_output) + output_offset += output_size + codebooks_offset += num_codebooks + return output + + +# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 +# at 6 and 9 times faster than the generic version above, respectively. +def optimized_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: List[int], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + # scaling the output is fastest, so we do that when possible. + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +class AQLMConfig(QuantizationConfig): + """Config class for AQLM. + + Reference: https://github.com/Vahe1994/AQLM + """ + + def __init__( + self, + in_group_size: int, + nbits_per_codebook: int, + num_codebooks: int, + out_group_size: int, + ) -> None: + super().__init__() + self.in_group_size = in_group_size + self.nbits_per_codebook = nbits_per_codebook + self.num_codebooks = num_codebooks + self.out_group_size = out_group_size + + # out_group_size > 1 is untested, and probably won't work as-is. 
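+        # With out_group_size fixed at 1, pack_factor reduces to in_group_size:
+        # each integer code indexes a codebook vector that covers
+        # in_group_size input channels of a single output channel.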
+ assert (self.out_group_size == 1) + self.pack_factor = (self.in_group_size * self.out_group_size) + + def __repr__(self) -> str: + return (f"AQLMConfig(in_group_size={self.in_group_size}, " + f"nbits_per_codebook={self.nbits_per_codebook}, " + f"num_codebooks={self.num_codebooks}, " + f"out_group_size={self.out_group_size})") + + @classmethod + def get_name(cls) -> str: + return "aqlm" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": + in_group_size = cls.get_from_keys(config, ["in_group_size"]) + nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) + num_code_books = cls.get_from_keys(config, ["num_codebooks"]) + out_group_size = cls.get_from_keys(config, ["out_group_size"]) + return cls(in_group_size, nbits_per_codebook, num_code_books, + out_group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AQLMLinearMethod"]: + if isinstance(layer, LinearBase): + return AQLMLinearMethod(self) + return None + + +class AQLMLinearMethod(LinearMethodBase): + """Linear method for AQLM. + + Args: + quant_config: The AQLM quantization config. + """ + + def __init__(self, quant_config: AQLMConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.half: + raise ValueError("Only half is currently supported by aqlm") + if input_size_per_partition % self.quant_config.in_group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.out_group_size != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + codes = Parameter( + torch.empty( + # There could actually be two pack factors, one along input and + # one along output, but we don't currently support + # out_group_size, and only the one along output needs to be + # marked with "packed_dim" in order for QKVLinear to work. 
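+                # Resulting shape:
+                # [output_size_per_partition,
+                #  input_size_per_partition // pack_factor, num_codebooks].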
+ output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + self.quant_config.num_codebooks, + dtype=get_int_dtype(self.quant_config.nbits_per_codebook), + ), + requires_grad=False, + ) + + set_weight_attrs( + codes, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + codebooks = Parameter( + torch.empty( + self.quant_config.num_codebooks * len(output_partition_sizes), + 2**self.quant_config.nbits_per_codebook, + self.quant_config.out_group_size, + self.quant_config.in_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + codebooks, + { + # metadata indicates fixed size concatenated along dim 0 + "is_metadata": True, + "output_partition_sizes": output_partition_sizes + }, + ) + + scales = Parameter( + torch.empty( + ( + output_size_per_partition // + self.quant_config.out_group_size, + 1, + 1, + 1, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "output_dim": 0, + "packed_dim": 0, + "pack_factor": self.quant_config.out_group_size + }, + ) + + layer.register_parameter("codes", codes) + set_weight_attrs(codes, extra_weight_attrs) + layer.register_parameter("codebooks", codebooks) + set_weight_attrs(codebooks, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + codebooks = layer.codebooks + codes = layer.codes + scales = layer.scales + output_partition_sizes = getattr(codebooks, "output_partition_sizes", + []) + + nbooks = codes.shape[2] + ingroups = codebooks.shape[3] + outgroups = codebooks.shape[2] + bits = codebooks.shape[1] + + # We support these formats with dedicated gemm and decompression + # kernels. + if ingroups == 8 and outgroups == 1 and ( + (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): + + # thresholds determined by timings on an A6000, one GPU + use_gemv = math.prod(x.shape[:-1]) <= 6 + + return ops.aqlm_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) if use_gemv else optimized_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) + + # fall back all unoptimized formats + return generic_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py new file mode 100644 index 0000000..a4812e8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) + + +class AWQConfig(QuantizationConfig): + """Config class for AWQ. 
+ + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + modules_to_not_convert: Optional[List[str]] = None, + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.modules_to_not_convert = modules_to_not_convert or [] + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return (f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"modules_to_not_convert={self.modules_to_not_convert})") + + def get_name(self) -> str: + return "awq" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(weight_bits, group_size, zero_point, modules_to_not_convert) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["LinearMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return AWQLinearMethod(self) + return None + + +def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]): + return any(module_name in prefix for module_name in modules_to_not_convert) + + +class AWQLinearMethod(LinearMethodBase): + """Linear method for AWQ. + + Args: + quant_config: The AWQ quantization config. + """ + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + + weight_loader = extra_weight_attrs.get("weight_loader") + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + qzeros = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.qweight = torch.nn.Parameter(layer.qweight.data, + requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, + requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, + requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + reshaped_x = x.reshape(-1, x.shape[-1]) + + # num_tokens >= threshold + # FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 + FP16_MATMUL_HEURISTIC_CONDITION = False + if FP16_MATMUL_HEURISTIC_CONDITION: + out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) + out = torch.matmul(reshaped_x, out) + else: + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, + pack_factor, group_size=self.quant_config.group_size) + if bias is not None: + out.add_(bias) + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py new file mode 100644 index 0000000..185e300 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -0,0 +1,677 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.nn import Parameter + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs) +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + is_layer_skipped_awq) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + 
check_marlin_supports_layer, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales, + moe_awq_to_marlin_zero_points, verify_marlin_supported, + verify_marlin_supports_shape) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +import ixformer.inference.functions as ixfops + +logger = init_logger(__name__) + + +class AWQMarlinConfig(QuantizationConfig): + """Config class for AWQ Marlin""" + + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + def __init__(self, weight_bits: int, group_size: int, zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[List[str]], + full_config: Dict[str, Any]) -> None: + super().__init__() + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.zero_point = zero_point + self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits + self.modules_to_not_convert = modules_to_not_convert or [] + self.full_config = full_config + + if self.weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " + f"Supported num_bits = {self.TYPE_MAP.keys()}") + + self.quant_type = self.TYPE_MAP[self.weight_bits] + + verify_marlin_supported(self.quant_type, + group_size=self.group_size, + has_zp=self.zero_point) + + def __repr__(self) -> str: + return (f"AWQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})") + + @classmethod + def get_name(cls) -> str: + return "awq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(weight_bits, group_size, zero_point, lm_head_quantized, + modules_to_not_convert, config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) + is_valid_user_quant = (user_quant is None or user_quant == "marlin" + or user_quant == "awq_marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "awq": + logger.info("Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. 
Use quantization=awq_marlin for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + # Check if the layer is supported by AWQMarlin. + if not check_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMarlin. " + "Falling back to unoptimized AWQ kernels.") + return AWQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + # if layer.local_num_experts > 32: + # # For MoEs with many experts the moe_wna16 kernel is faster + # return MoeWNA16Config.from_config( + # self.full_config).get_quant_method(layer, prefix) + # else: + # return AWQMoEMethod(self) + return AWQMoEMethod(self) + return None + + @classmethod + def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + zero_point = quant_config.get("zero_point") + + if not current_platform.is_cuda(): + return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or zero_point is None): + return False + + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], + group_size=group_size, + has_zp=zero_point) + + +class AWQMarlinLinearMethod(LinearMethodBase): + """Linear method for AWQ Marlin. + + Args: + quant_config: The AWQ Marlin quantization config. 
+ """ + + def __init__(self, quant_config: AWQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + num_groups = input_size_per_partition // group_size + + qzeros = PackedvLLMParameter( + data=torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.num_groups = num_groups + + # TODO: Update this docs + # Checkpoints are serialized in AutoAWQ format, which is different from the + # marlin format. This function is called after the weights are loaded. 
+ # Here, we handle the repacking + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.output_size_per_partition = layer.qweight.data.shape[1] * self.quant_config.pack_factor + align_bits = 64 * 8 + align_size = align_bits // self.quant_config.weight_bits + if layer.output_size_per_partition % align_size != 0: + padding_output_size_per_partition = (layer.output_size_per_partition + align_size - 1) // align_size * align_size + layer.output_padding_size = padding_output_size_per_partition - layer.output_size_per_partition + device = layer.qweight.device + + pad_qweight = torch.zeros( + layer.input_size_per_partition, + padding_output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + device=device, + ) + pad_qzeros = torch.zeros( + layer.num_groups, + padding_output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + device=device, + ) + pad_scales = torch.zeros( + layer.num_groups, + padding_output_size_per_partition, + dtype=layer.scales.data.dtype, + device=device, + ) + pad_qweight[..., :layer.output_size_per_partition // self.quant_config.pack_factor] = layer.qweight.data + pad_qzeros[..., :layer.output_size_per_partition // self.quant_config.pack_factor] = layer.qzeros.data + pad_scales[..., :layer.output_size_per_partition] = layer.scales.data + replace_parameter(layer, "qweight", pad_qweight) + replace_parameter(layer, "qzeros", pad_qzeros) + replace_parameter(layer, "scales", pad_scales) + return + # TODO(gyf) Marlin format is not support for now.. + device = layer.qweight.device + layer.qweight = torch.nn.Parameter(layer.qweight.data, + requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, + requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, + requires_grad=False) + + # Allocate marlin workspace + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) + + # Repack weights from AWQ format to marlin format. + marlin_qweight = ops.awq_marlin_repack( + layer.qweight, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits) + replace_parameter(layer, "qweight", marlin_qweight) + + # Permute scales from AWQ format to marlin format. + marlin_scales = marlin_permute_scales( + layer.scales, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size) + replace_parameter(layer, "scales", marlin_scales) + + # Permute zero-points from AWQ format to marlin format. + marlin_zp = awq_to_marlin_zero_points( + layer.qzeros, + size_k=layer.num_groups, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits) + replace_parameter(layer, "qzeros", marlin_zp) + + # Not-used + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO use awq kernel temporarily.. 
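+        # Fallback path: until the Marlin repack is enabled, run the same
+        # unpacked AWQ GEMM used by AWQLinearMethod on the (possibly padded)
+        # qweight/qzeros/scales tensors; the Marlin call is kept below for
+        # reference.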
+ qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + reshaped_x = x.reshape(-1, x.shape[-1]) + + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, + pack_factor, group_size=self.quant_config.group_size) + if bias is not None: + out.add_(bias) + return out.reshape(out_shape) + # return apply_awq_marlin_linear( + # input=x, + # weight=layer.qweight, + # weight_scale=layer.scales, + # weight_zp=layer.qzeros, + # g_idx=layer.g_idx, + # g_idx_sort_indices=layer.g_idx_sort_indices, + # workspace=layer.workspace, + # quant_type=self.quant_config.quant_type, + # output_size_per_partition=layer.output_size_per_partition, + # input_size_per_partition=layer.input_size_per_partition, + # bias=bias) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP.value, + }) + + w13_qweight = Parameter( + torch.empty(num_experts, + hidden_size, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size_per_partition, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = (intermediate_size_per_partition // + self.quant_config.group_size) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size_per_partition * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = Parameter( + torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + return + # TODO(gyf) Marlin format is not support for now.. 
+ num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + + # Why does this take the intermediate size for size_k? + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + replace_parameter(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_parameter(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0, + ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." 
+ use_ep = expert_map is not None + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, global_num_experts) + if apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused Marlin MoE method.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + num_tokens, num_experts = router_logits.shape + if use_ep: + hidden_size = x.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=x.device, + dtype=x.dtype, + ) + else: + expand_tokens = num_tokens * top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + expert_sizes_cpu = expert_sizes_gpu.cpu() + + # expand + reorder + # TODO use kernel + expand_hidden_states = ixfops.moe_expand_input( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=top_k, + src_to_dst=src_to_dst, + ) + + # w4a16 group gemm 1 + # pt_output_1: (expand_tokens, 2n) dtype + pt_output_1 = ixfops.moe_w4a16_group_gemm( + input=expand_hidden_states, + weight=layer.w13_qweight, + w_scales=layer.w13_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w13_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=None, + format="NN", + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + # act + pt_output_2 = ixfops.silu_and_mul(pt_output_1) + + # w4a16 group gemm 2 + reorder + # pt_output_3: (expand_tokens, k) dtype + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + ixfops.moe_w4a16_group_gemm( + input=pt_output_2, + weight=layer.w2_qweight, + w_scales=layer.w2_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w2_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=sorted_token_ids, + format="NN", + output=pt_output_3, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weights, + extra_residual=extra_residual, + scaling_factor=routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w4a16_group_gemm( + input=pt_output_2, + weight=layer.w2_qweight, + w_scales=layer.w2_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w2_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=sorted_token_ids, + format="NN", + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + # mul + reduce_sum + # final_hidden_states: (num_tokens, k) + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weights, + extra_residual=extra_residual, + scaling_factor=routed_scaling_factor + ) + return final_hidden_states + # return 
torch.ops.vllm.fused_marlin_moe( + # x, + # layer.w13_qweight, + # layer.w2_qweight, + # layer.w13_scales, + # layer.w2_scales, + # router_logits, + # topk_weights, + # topk_ids, + # w1_zeros=layer.w13_qzeros, + # w2_zeros=layer.w2_qzeros, + # num_bits=self.quant_config.weight_bits, + # ) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py new file mode 100644 index 0000000..09efd4d --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + +AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( + 0, BLOCK_SIZE_X * 8) + result_offsets = (8 * num_cols * result_offsets_y[:, None] + + result_offsets_x[None, :]) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks, 0.0) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. 
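+    # Zeros are packed eight 4-bit values per int32 in AWQ order; the three
+    # tl.interleave calls replicate each packed word 8x so the shifts computed
+    # above can extract every nibble in one vectorized step.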
+ zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + + tl.arange(0, BLOCK_SIZE_X * 8)) + scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + + scale_offsets_x[None, :]) + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, + group_size, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = c_ptr.type.element_ty + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=accumulator_dtype) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Create the necessary shifts to use to unpack. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], + (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. 
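+    # A offsets address the activation matrix [M, K]; B and zero offsets
+    # address packed int32 columns (N // 8), while scale offsets address the
+    # unpacked N columns.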
+ offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_bn = offsets_bn < N // 8 + + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_zn = offsets_zn < N // 8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a, other=0.0) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b, other=0.0) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + + # Dequantize b. + offsets_szk = ( + (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + + tl.arange(0, 1)) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s, other=0.0) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. 
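+        # Each SPLIT_K program accumulates its own K-slice; partial results are
+        # written to separate planes of c (indexed by pid_z) and summed on the
+        # host in awq_gemm_triton.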
+ accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton(qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + group_size = qweight.shape[0] // scales.shape[0] + + assert K > 0 and M > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == M + assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty(qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype) + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + + grid = lambda META: ( + triton.cdiv(X, META['BLOCK_SIZE_X']), + triton.cdiv(Y, META['BLOCK_SIZE_Y']), + ) + awq_dequantize_kernel[grid](qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. 
+def awq_gemm_triton(input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( + N, META['BLOCK_SIZE_N']), + split_k_iters, + ) + + result = torch.zeros((split_k_iters, M, N), + dtype=scales.dtype, + device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid](input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters) + + result = result.sum(0) + + return result diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py new file mode 100644 index 0000000..5ef1154 --- /dev/null +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Type + +import torch +from torch import nn + + +class QuantizeMethodBase(ABC): + """Base class for different quantized methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, *weight_args, + **extra_weight_attrs): + """Create weights for a layer. + + The weights will be set as attributes of the layer.""" + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + # Not required functions + def embedding(self, layer: torch.nn.Module, *args, + **kwargs) -> torch.Tensor: + """Gather embeddings in the layer based on indices in the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + def process_weights_after_loading(self, layer: nn.Module) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + return + + +def method_has_implemented_embedding( + method_class: Type[QuantizeMethodBase]) -> bool: + """ + Not all quant methods have embedding implemented, so we need to check that + it exists for our given method. We check this by making sure the function + has been changed from the base implementation. 
+ """ + base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", + None) + class_embedding = inspect.getattr_static(method_class, "embedding", None) + + return (class_embedding is not None + and class_embedding is not base_embedding) + + +class QuantizationConfig(ABC): + """Base class for quantization configs.""" + + def __init__(self): + super().__init__() + # mapping is updated by models as they initialize + self.packed_modules_mapping: Dict[str, List[str]] = dict() + + @abstractmethod + def get_name(self) -> str: + """Name of the quantization method.""" + raise NotImplementedError + + @abstractmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + """List of supported activation dtypes.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. + This requirement is due to the custom CUDA kernels used by the + quantization method. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + + @staticmethod + def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError(f"Cannot find any of {keys} in the model's " + "quantization config.") + + @staticmethod + def get_from_keys_or(config: Dict[str, Any], keys: List[str], + default: Any) -> Any: + """Get a optional value from the model's quantization config.""" + try: + return QuantizationConfig.get_from_keys(config, keys) + except ValueError: + return default + + @abstractmethod + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + prefix: The full name of the layer in the state dict + Returns: + The quantize method. None if the given layer doesn't support quant + method. 
+ """ + raise NotImplementedError + + def get_cache_scale(self, name: str) -> Optional[str]: + return None diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py new file mode 100644 index 0000000..704531b --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.utils import direct_register_custom_op + + +class BitsAndBytesConfig(QuantizationConfig): + """Config class for BitsAndBytes Quantization. + + Reference: https://arxiv.org/abs/2305.14314 + """ + + def __init__( + self, + load_in_8bit: bool = False, + load_in_4bit: bool = True, + bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_storage: str = "uint8", + bnb_4bit_quant_type: str = "fp4", + bnb_4bit_use_double_quant: bool = False, + llm_int8_enable_fp32_cpu_offload: bool = False, + llm_int8_has_fp16_weight: bool = False, + llm_int8_skip_modules: Optional[List[str]] = None, + llm_int8_threshold: float = 6.0, + ) -> None: + super().__init__() + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.llm_int8_skip_modules = llm_int8_skip_modules or [] + self.llm_int8_threshold = llm_int8_threshold + + if self.bnb_4bit_quant_storage not in ["uint8"]: + raise ValueError("Unsupported bnb_4bit_quant_storage: " + f"{self.bnb_4bit_quant_storage}") + + def __repr__(self) -> str: + return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " + f"load_in_4bit={self.load_in_4bit}, " + f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " + f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " + f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " + f"llm_int8_skip_modules={self.llm_int8_skip_modules})") + + @classmethod + def get_name(self) -> str: + return "bitsandbytes" + + @classmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.float32, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "adapter_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": + + def get_safe_value(config, keys, default_value=None): + try: + value = cls.get_from_keys(config, keys) + return value if value is not None else default_value + except ValueError: + return default_value + + load_in_8bit = get_safe_value(config, ["load_in_8bit"], + default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], + default_value=True) + bnb_4bit_compute_dtype = get_safe_value(config, + ["bnb_4bit_compute_dtype"], + default_value="float32") + bnb_4bit_quant_storage = get_safe_value(config, + ["bnb_4bit_quant_storage"], + default_value="uint8") + bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], + default_value="fp4") + 
bnb_4bit_use_double_quant = get_safe_value( + config, ["bnb_4bit_use_double_quant"], default_value=False) + llm_int8_enable_fp32_cpu_offload = get_safe_value( + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) + llm_int8_has_fp16_weight = get_safe_value(config, + ["llm_int8_has_fp16_weight"], + default_value=False) + llm_int8_skip_modules = get_safe_value(config, + ["llm_int8_skip_modules"], + default_value=[]) + llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], + default_value=6.0) + + return cls( + load_in_8bit=load_in_8bit, + load_in_4bit=load_in_4bit, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_storage=bnb_4bit_quant_storage, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + llm_int8_skip_modules=llm_int8_skip_modules, + llm_int8_threshold=llm_int8_threshold) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["LinearMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules): + return UnquantizedLinearMethod() + return BitsAndBytesLinearMethod(self) + return None + + +def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]): + # Split the prefix into its dot-separated components + components = prefix.split('.') + + # Check if any of the skip modules exactly matches any component + substr_check = any(module_name in components + for module_name in llm_int8_skip_modules) + + # Allow certain layers to not be quantized + set_components = set(".".join(components[:i + 1]) + for i in range(len(components))) + set_llm_int8_skip_modules = set(llm_int8_skip_modules) + prefix_check = len(set_llm_int8_skip_modules & set_components) != 0 + + return substr_check or prefix_check + + +class BitsAndBytesLinearMethod(LinearMethodBase): + """Linear method for BitsAndBytes. + + Args: + quant_config: The BitsAndBytes quantization config. + """ + + def __init__(self, quant_config: BitsAndBytesConfig): + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.45.3": + raise ImportError("bitsandbytes version is wrong. 
Please " + "install bitsandbytes>=0.45.3.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.45.3 via " + "`pip install bitsandbytes>=0.45.3` to use " + "bitsandbytes quantizer.") from err + + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + from bitsandbytes.nn import Int8Params + + def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + def create_qweight_for_8bit(): + qweight = Int8Params( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": 1, + "use_bitsandbytes_8bit": True, + "generation": 0 + }) + return qweight + + def create_qweight_for_4bit(): + quant_ratio = calculate_quant_ratio(params_dtype) + + total_size = input_size_per_partition * sum(output_partition_sizes) + if total_size % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape.") + + qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, + 1, + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True + }) + return qweight + + if self.quant_config.load_in_8bit: + qweight = create_qweight_for_8bit() + else: + qweight = create_qweight_for_4bit() + # Enable parameters to have the same name as in the BNB + # checkpoint format. 
+ layer.register_parameter("weight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.quant_config.load_in_8bit: + return self._apply_8bit_weight(layer, x, bias) + else: + return self._apply_4bit_weight(layer, x, bias) + + def _apply_8bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import MatmulLtState, matmul + + original_type = x.dtype + original_shape = x.shape + reshape_after_matmul = False + if x.ndim > 2: + x = x.reshape(-1, x.size(-1)) + reshape_after_matmul = True + bf_x = x.to(torch.bfloat16) + + qweight = layer.weight + offsets = qweight.bnb_shard_offsets + quant_states = qweight.bnb_quant_state + matmul_states = qweight.matmul_state + generation = qweight.generation + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.float16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + + # in profile_run or the first generation of inference, + # create new matmul_states + if generation == 0 or generation == 1: + matmul_states[i] = MatmulLtState() + matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].SCB = quant_states[i].to(x.device) + matmul_states[i].threshold = ( + self.quant_config.llm_int8_threshold) + matmul_states[i].has_fp16_weights = ( + self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].is_training = False + if matmul_states[i].threshold > 0.0 and not matmul_states[ + i].has_fp16_weights: + matmul_states[i].use_pool = True + + new_x = bf_x.unsqueeze(0) + + out[:, current_index:current_index + output_size] = matmul( + new_x, + qweight[offsets[i]:offsets[i + 1]], + state=matmul_states[i]) + + current_index += output_size + + # only update the matmul_states if it is not profile_run + if (generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None): + del matmul_states[i].CB + qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + + out = out.to(original_type) + + if reshape_after_matmul: + out = out.view(*original_shape[:-1], out.size(-1)) + + if bias is not None: + out += bias + + qweight.generation += 1 + + return out + + def _apply_4bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + original_type = x.dtype + original_shape = x.shape + reshape_after_matmul = False + if x.ndim > 2: + x = x.reshape(-1, x.size(-1)) + reshape_after_matmul = True + bf_x = x.to(torch.bfloat16) + + qweight = layer.weight + quant_states = qweight.bnb_quant_state + offsets = qweight.bnb_shard_offsets + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.bfloat16, + device=x.device) + apply_bnb_4bit(bf_x, qweight, offsets, out) + out = out.to(original_type) + + if reshape_after_matmul: + out = out.view(*original_shape[:-1], out.size(-1)) + + if bias is not None: + out += bias + + return out + + +def _apply_bnb_4bit( + x: torch.Tensor, + weight: torch.Tensor, + offsets: torch.Tensor, + out: torch.Tensor, +) -> None: + # only 
load the bitsandbytes module when needed
+    from bitsandbytes import matmul_4bit
+    quant_states = weight.bnb_quant_state
+    current_index = 0
+    for i in range(len(quant_states)):
+        output_size = quant_states[i].shape[0]
+        # It is more efficient to use out kwarg like
+        # matmul_4bit(..., out = ...). Infeasible now due to the bug
+        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
+        # Need to change after the bug is fixed.
+        out[:, current_index:current_index + output_size] = matmul_4bit(
+            x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
+        current_index += output_size
+
+
+def _apply_bnb_4bit_fake(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    offsets: torch.Tensor,
+    out: torch.Tensor,
+) -> None:
+    return
+
+
+try:
+    direct_register_custom_op(
+        op_name="apply_bnb_4bit",
+        op_func=_apply_bnb_4bit,
+        mutates_args=["out"],
+        fake_impl=_apply_bnb_4bit_fake,
+    )
+    apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
+
+except AttributeError as error:
+    raise error
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
new file mode 100644
index 0000000..2efaf5a
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -0,0 +1,625 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from contextlib import suppress
+from typing import Any, Dict, List, Literal, Optional, Tuple, cast
+
+import torch
+from compressed_tensors.config import (CompressionFormat,
+                                       SparsityCompressionConfig,
+                                       SparsityStructure)
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
+from pydantic import BaseModel
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
+    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target, is_activation_quantization_format,
+    should_ignore_layer)
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+__all__ = ["CompressedTensorsLinearMethod"]
+
+SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
+QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
+
+
+class CompressedTensorsConfig(QuantizationConfig):
+
+    def __init__(
+        self,
+        target_scheme_map: Dict[str, Any],
+        ignore: List[str],
+        quant_format: str,
+        sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
+        sparsity_ignore_list: List[str],
+        kv_cache_scheme: Optional[Dict[str, Any]] = None,
+        config: 
Optional[Dict[str, Any]] = None, + ): + super().__init__() + self.ignore = ignore + self.quant_format = quant_format + # Map from [target -> scheme] + self.target_scheme_map = target_scheme_map + self.kv_cache_scheme = kv_cache_scheme + self.sparsity_scheme_map = sparsity_scheme_map + self.sparsity_ignore_list = sparsity_ignore_list + self.config = config + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> str: + return "compressed_tensors" + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + # Check if the layer is skipped for quantization. + # TODO (@robertgshaw2): support module names + if should_ignore_layer(prefix, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + if scheme is None: + return UnquantizedLinearMethod() + layer.scheme = scheme + return CompressedTensorsLinearMethod(self) + if isinstance(layer, Attention): + return CompressedTensorsKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod.get_moe_method( + self, layer.activation, layer.expert_map) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + ignore: List[str] = cast(List[str], config.get("ignore", [])) + quant_format = cast(str, config.get("format")) + target_scheme_map = cls._quantization_scheme_map_from_config( + config=config) + sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( + config=config) + + return cls( + target_scheme_map=target_scheme_map, + ignore=ignore, + quant_format=quant_format, + sparsity_scheme_map=sparsity_scheme_map, + sparsity_ignore_list=sparsity_ignore_list, + config=config, + ) + + @classmethod + def _parse_sparsity_config( + cls, config: Dict[str, Any] + ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]: + """ + :param config: The `quantization_config` dictionary from config.json + :return: A tuple with two elements + 1. A dictionary mapping target layer names to their corresponding + sparsity_config + 2. 
A list of layer names to ignore for sparsity + """ + if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)): + return dict(), [] + + sparsity_config = SparsityCompressionConfig.model_validate( + sparsity_config) + sparse_scheme_map: Dict[str, SparsityCompressionConfig] = { + target: sparsity_config + for target in sparsity_config.targets or list() + } + sparsity_ignore_list = sparsity_config.ignore or list() + return sparse_scheme_map, sparsity_ignore_list + + @classmethod + def _quantization_scheme_map_from_config( + cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: + """ + :param config: The `quantization_config` dictionary from config.json + :return: A dictionary mapping target layer names to their corresponding + quantization_args for weights and input activations + """ + target_scheme_map: Dict[str, Any] = dict() + quant_format = cast(str, config.get("format")) + + # The quant_config has multiple config_groups, each containing + # an input_activations key with details about how the activations are + # quantized, a weights key indicating how the weights are quantized, + # and a list of targets under the `targets` key, dictating which + # layers are impacted by the quantization details. The quantization + # details follow the structure defined by the QuantizationArgs + # pydantic model, which is used to verify the structure of the + # quant_config and also store the details for later use. + + config_groups = config.get("config_groups", dict()) + for _, quant_config in config_groups.items(): + targets = quant_config.get("targets") + for target in targets: + target_scheme_map[target] = {} + target_scheme_map[target][ + "weights"] = QuantizationArgs.model_validate( + quant_config.get("weights")) + + target_scheme_map[target]["input_activations"] = None + if is_activation_quantization_format(quant_format): + input_activations = quant_config.get("input_activations") + # The only case where we have activation quant supported + # but no input_activations provided in the config + # should be w8a16fp8 w8a16fp8 can also run for cases where + # there is an input_quant but it is ignored + if not input_activations: + assert target_scheme_map[target][ + "weights"].type == QuantizationType.FLOAT + else: + target_scheme_map[target][ + "input_activations"] = QuantizationArgs.model_validate( # noqa: E501 + quant_config.get("input_activations")) + return target_scheme_map + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _check_scheme_supported(self, + min_capability: int, + error: bool = True, + match_exact: bool = False) -> bool: + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if match_exact: + supported = capability == min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + "the current GPU. Required capability: ", + f"{min_capability}. Current capability: {capability}.") + else: + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. 
", + f"Current capability: {capability}.") + return supported + else: + return False + + def _is_static_tensor_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_tensor = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TENSOR.value) + is_static = not weight_quant.dynamic and not input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_tensor and weight_quant.symmetric and is_static + + def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_token and weight_quant.symmetric and is_dynamic + + def _is_fp8_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights and activations quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm weight scheme is supported. + is_floating_point = (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT) + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) + if not (is_floating_point and is_symmetric_weight and is_static_weight + and is_per_tensor_or_channel_weight): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.dynamic: + return True + + # Confirm activation scheme is supported. + is_symmetric_activation = input_quant.symmetric + is_per_tensor_activation = ( + input_quant.strategy == QuantizationStrategy.TENSOR) + return is_symmetric_activation and is_per_tensor_activation + + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) + + def _is_fp8_w8a16(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights quantized. + if weight_quant is None: + return False + + # Confirm we have floating points. + if weight_quant.type != QuantizationType.FLOAT: + return False + + # Confirm weight scheme is supported. + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) + if not (is_symmetric_weight and is_static_weight # noqa: SIM103 + and is_per_tensor_or_channel_weight): + return False + + # All conditions satisfied. 
+ return True + + def _is_wNa16_group_channel(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + input_quant_none = input_quant is None + is_symmetric = weight_quant.symmetric + is_channel_group = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_static = not weight_quant.dynamic + + return (is_channel_group and input_quant_none and is_symmetric + and is_static) + + def _get_scheme_from_parts( + self, weight_quant: BaseModel, + input_quant: BaseModel) -> "CompressedTensorsScheme": + + # Detect If Mixed Precision + if self._is_wNa16_group_channel(weight_quant, input_quant): + if (self.quant_format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + return CompressedTensorsW4A16Sparse24( + strategy=weight_quant.strategy, + num_bits=weight_quant.num_bits, + group_size=weight_quant.group_size) + if (self.quant_format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS): + return CompressedTensorsWNA16( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + + if is_activation_quantization_format(self.quant_format): + if self._is_fp8_w8a8(weight_quant, input_quant): + is_fp8_w8a8_supported = self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + if is_fp8_w8a8_supported: + return CompressedTensorsW8A8Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) + else: + # note: input_quant will be present for converted models; + # will be ignored during inference post loading + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=not input_quant.dynamic) + + # note: input_quant can be None + if self._is_fp8_w8a16(weight_quant, input_quant): + is_static_input_scheme = (input_quant + and not input_quant.dynamic) + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=is_static_input_scheme) + + if self._is_static_tensor_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=True, + input_symmetric=input_quant.symmetric) + + if self._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False, + input_symmetric=input_quant.symmetric) + + raise NotImplementedError( + "No compressed-tensors compatible scheme was found.") + + def get_scheme(self, + layer: torch.nn.Module, + layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: + """ + compressed-tensors supports non uniform in the following way: + + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. + + Detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsScheme used for infernece. + """ + + # Find the "target" in the compressed-tensors config + # that our layer conforms to. 
+ # TODO (@robertgshaw): add compressed-tensors as dep + # so we do not have to re-write these functions + # need to make accelerate optional in ct to do this + + # Will be empty for models with only sparsity + weight_quant = input_quant = None + if self.target_scheme_map: + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=self.target_scheme_map.keys(), + fused_mapping=self.packed_modules_mapping) + + scheme_dict = self.target_scheme_map[matched_target] + weight_quant = scheme_dict.get("weights") + input_quant = scheme_dict.get("input_activations") + + # Find the sparsity scheme of the layer + # assume that fused layers inerhit first component's sparsity scheme + sparsity_targets = (self.sparsity_scheme_map.keys() - + set(self.sparsity_ignore_list)) + sparsity_scheme: Optional[SparsityCompressionConfig] = None + with suppress(ValueError): + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=sparsity_targets, + fused_mapping=self.packed_modules_mapping) + sparsity_scheme = self.sparsity_scheme_map[matched_target] + + if self.supports_cutlass_24(weight_quant=weight_quant, + input_quant=input_quant, + sparsity_scheme=sparsity_scheme): + # Have a valid sparsity scheme + # Validate layer is supported by Cutlass 2:4 Kernel + model_compression_config = (None if sparsity_scheme is None + or sparsity_scheme.format == "dense" + else self.config) + + scheme = CompressedTensors24( + quantized=weight_quant is not None or input_quant is not None, + weight_quant=weight_quant, + input_quant=input_quant, + model_compression_config=model_compression_config, + ) + elif weight_quant is None: + logger.warning_once("Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. " + "Falling back to UnquantizedLinearMethod") + return None + + else: + # Find the quant_scheme + scheme = self._get_scheme_from_parts( # type: ignore + weight_quant=weight_quant, + input_quant=input_quant, + ) + + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, + layer_name) + return scheme + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. 
If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + # If no matches, return None + return None + + @staticmethod + def supports_cutlass_24( + weight_quant: Optional[QuantizationArgs], + input_quant: Optional[QuantizationArgs], + sparsity_scheme: Optional[SparsityCompressionConfig] = None + ) -> bool: + """ + Check if the layer is supported by the Cutlass 2:4 Kernel + Conditions: + - Overarching condition: Sparsity Structure is 2:4 + - Unquantized cases are supported + - Weight only quantization is not-supported + - Supported weight quantization strategies are TENSOR and CHANNEL + - Supported input quantization strategies are TENSOR and TOKEN + - Only 8 bit quantization is supported + + :return: True if the layer is supported by the Cutlass 2:4 Kernel + False otherwise + """ + if sparsity_scheme is None: + return False + + is_valid_sparsity_structure: bool = ( + sparsity_scheme.sparsity_structure == + SparsityStructure.TWO_FOUR.value) + + valid_compressors = { + CompressionFormat.dense.value, + CompressionFormat.sparse_24_bitmask.value + } + + is_valid_sparsity = (is_valid_sparsity_structure + and sparsity_scheme.format in valid_compressors) + + if not is_valid_sparsity: + return False + + # Unquantized cases are supported + if weight_quant is None and input_quant is None: + return True + + # Weight only quantization is not-supported + if weight_quant is not None and input_quant is None: + return False + + supported_weight_quant_strategies = [ + QuantizationStrategy.TENSOR.value, + QuantizationStrategy.CHANNEL.value + ] + + assert weight_quant is not None + assert input_quant is not None + if weight_quant.strategy not in supported_weight_quant_strategies: + return False + + supported_input_quant_strategies = [ + QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value + ] + + if input_quant.strategy not in supported_input_quant_strategies: + return False + + return weight_quant.num_bits == input_quant.num_bits == 8 + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. 
See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + + +class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from compressed-tensors + checkpoints. + """ + + def __init__(self, quant_config: CompressedTensorsConfig): + self.validate_kv_cache_scheme(quant_config.kv_cache_scheme) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]): + """ + Validator for the kv cache scheme. Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_scheme: the compressed-tensors kv cache scheme + """ + if kv_cache_scheme is None: + return + + type_ = kv_cache_scheme.get("type") + num_bits = kv_cache_scheme.get("num_bits") + + if type_ != "float" and num_bits != 8: + raise NotImplementedError( + "Currently supported kv cache quantization is " + "num_bits=8, type=float, however " + f"received num_bits={num_bits}, type={type_}") + + strategy = kv_cache_scheme.get("strategy") + if strategy != "tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for compressed-tensors KV cache. " + f"Expected strategy: tensor, found strategy: {strategy}") + + is_symmetric = kv_cache_scheme.get("symmetric") + if not is_symmetric: + raise NotImplementedError( + "Only support symmetric scaling factor " + "for compressed-tensors KV cache. 
" + f"However found symmetric: {is_symmetric}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py new file mode 100644 index 0000000..bb50862 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -0,0 +1,1461 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +from enum import Enum +from typing import Callable, List, Optional + +import torch +from compressed_tensors import CompressionFormat +from compressed_tensors.quantization import QuantizationStrategy + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) +import ixformer.inference.functions as ixfops +import vllm.envs as envs + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = [ + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsW8A8Fp8MoECutlassMethod", + "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsW8A8Int8MoEMethod" +] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + + @staticmethod + def get_moe_method( + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + activation: str, + expert_map: Optional[torch.Tensor], + ) -> "CompressedTensorsMoEMethod": + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. 
+ weight_quant = quant_config.target_scheme_map["Linear"].get("weights") + input_quant = quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + return CompressedTensorsWNA16MoEMethod(quant_config) + elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + and activation == "silu" and expert_map is None): + return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant) or quant_config._is_static_tensor_w8a8(weight_quant, input_quant): + if envs.VLLM_W8A8_MOE_USE_W4A8: + return CompressedTensorsW4A8MoEMethod(quant_config) + else: + return CompressedTensorsW8A8Int8MoEMethod(quant_config) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") + + +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales " + "for weights and activations are supported. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + if current_platform.is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
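+        # Requantizing both shards with the per-expert max scale keeps every
+        # value in range; the shard quantized with the smaller original scale
+        # just loses a little precision to the coarser grid.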
+ assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + +class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy + == QuantizationStrategy.TENSOR) + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN) + if not (per_tensor or per_channel): + raise ValueError( + "For FP8 Fused MoE layers, we require per tensor " + "or channelwise, dynamic per token quantization. 
Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Allocate 2 scales for w1 and w3 respectively. + # They are combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSOR quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
+ extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + device = w13_weight.device + # TODO strides can be shared across multiple layers + self.ab_strides1 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full((num_experts, ), + 2 * intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full((num_experts, ), + intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides2 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + assert self.input_quant.strategy == QuantizationStrategy.TENSOR + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # For Per-TENSOR case, Fp8 moe kernel needs single weight scale + # for w13 per expert. Use max then dequant and requant each expert. 
+ if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0, + ) -> torch.Tensor: + + assert activation == "silu" + assert global_num_experts == layer.w13_weight.shape[0] + assert expert_map is None + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8 + + return cutlass_moe_fp8( + x, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + out_dtype=x.dtype, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if not (self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN): + raise ValueError( + "For INT8 Fused MoE layers, only per-channel scales" + "for weights and per-token scales for activations are supported. 
Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + hidden_size, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, intermediate_size_per_partition, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0, + ) -> torch.Tensor: + use_ep = expert_map is not None + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, global_num_experts) + topk_weight, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + 
custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + dtype = x.dtype + num_tokens, num_experts = router_logits.shape + + if use_ep: + hidden_size = x.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=x.device, + dtype=x.dtype, + ) + else: + expand_tokens = num_tokens * top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + expert_sizes_cpu = expert_sizes_gpu.cpu() + + # expand + reorder + quant + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + ) + + # w8a8 group gemm 1 + pt_output_1 = ixfops.moe_w8a8_group_gemm( + input=i8_hidden_states, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=None, + format="TN", + ) + + # act + quant + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) + + # w8a8 group gemm 2 + reorder + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + ixfops.moe_w8a8_group_gemm( + input=pt_output_2, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + format="TN", + output=pt_output_3, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weight, + extra_residual=extra_residual, + scaling_factor=routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w8a8_group_gemm( + input=pt_output_2, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + format="TN", + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weight, + extra_residual=extra_residual, + scaling_factor=routed_scaling_factor + ) + return final_hidden_states + + +class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + self.pack_factor = 2 + self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size + self.weight_symmetric = self.weight_quant.symmetric + self.gemm_format = envs.VLLM_W4A8_FORMAT + self.format_mapping = {"NN":0,"NT":1,"TN":2} + self.version = envs.VLLM_W4A8_VERSION + assert 
self.gemm_format in ["TN","NN"] + + if not ((self.weight_quant.strategy == QuantizationStrategy.CHANNEL + or self.weight_quant.strategy == QuantizationStrategy.GROUP) + and self.input_quant.strategy == QuantizationStrategy.TOKEN): + raise ValueError( + "For INT4 pack2 Fused MoE layers, only per-channel or group scales " + "for weights and per-token scales for activations are supported. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + if self.gemm_format == "TN": + w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor) + w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor) + else: + w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor) + w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor) + + # WEIGHTS + # use process_weights_after_loading to get the right layout if gemm_format is NN + w13_weight = torch.nn.Parameter(torch.empty(w13_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w13_weight, "shard_dim", 1) + + w2_weight = torch.nn.Parameter(torch.empty(w2_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w2_weight, "shard_dim", 0) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading.
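+        # Note on the weight shapes above: pack_factor == 2, i.e. two int4
+        # values share one int8 element (the exact nibble layout is defined by
+        # the ixfops kernels), so the last dimension is halved. "TN" and "NN"
+        # simply store the unpacked weights in transposed layouts of each other.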
+ # The following scale or zero will use permute(0,2,1) to get right layout, init here to avoid rewrite data_loader + w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.version == 2 else 1 if self.group_size == -1 else hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + setattr(w13_weight_scale, "shard_dim", 1) + + w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.version == 2 else 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + setattr(w2_weight_scale, "shard_dim", 0) + # setattr(w2_weight_scale, "load_full_w2", True) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.version == 2 or self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + if self.version == 2: + # INT8 -> INT4 weight scales/zeros + if self.group_size != -1: + w13_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_i8_weight_scale", w13_i8_weight_scale) + setattr(w13_i8_weight_scale, "shard_dim", 1) + if not self.weight_symmetric: + w13_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.group_size == -1 else hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_i8_weight_zero", w13_i8_weight_zero) + setattr(w13_i8_weight_zero, "shard_dim", 1) + + if self.group_size != -1: + w2_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_i8_weight_scale", w2_i8_weight_scale) + setattr(w2_i8_weight_scale, "shard_dim", 0) + if not self.weight_symmetric: + w2_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_i8_weight_zero", w2_i8_weight_zero) + setattr(w2_i8_weight_zero, "shard_dim", 0) + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) + + if self.version == 2 and self.group_size != -1: + set_weight_attrs(w13_i8_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_i8_weight_scale, extra_weight_attrs) + else: + setattr(layer, "w13_i8_weight_scale", None) + setattr(layer, "w2_i8_weight_scale", None) + if self.version == 2 and not self.weight_symmetric: + set_weight_attrs(w13_i8_weight_zero, extra_weight_attrs) + set_weight_attrs(w2_i8_weight_zero, extra_weight_attrs) + else: + setattr(layer, "w13_i8_weight_zero", None) + setattr(layer, "w2_i8_weight_zero", None) + + # DO NOT SUPPORT INPUT_SCALES + if self.static_input_scales: + extra_weight_attrs.update( + {"quant_method": 
FusedMoeWeightScaleSupported.CHANNEL.value}) + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, intermediate_size_per_partition, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + self.gemm_format = self.format_mapping[self.gemm_format] + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0, + ) -> torch.Tensor: + attn_metadata = get_forward_context().attn_metadata + use_ep = expert_map is not None + # unsupported ep now + only_decode = use_ep == False and attn_metadata.decode_metadata is not None and attn_metadata.prefill_metadata is None + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, global_num_experts) + topk_weight, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + dtype = x.dtype + num_tokens, num_experts = router_logits.shape + if use_ep: + hidden_size = x.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=x.device, + dtype=x.dtype, + ) + else: + expand_tokens = num_tokens * top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + + if only_decode: + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + output_format = 1, + ) + + pt_output_1 = ixfops.moe_w4a8_group_gemv( + input=i8_hidden_states, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_gpu, + w_i8scales=layer.w13_i8_weight_scale, + w_i8zeros=layer.w13_i8_weight_zero, + dst_to_src=None, + format=self.gemm_format, + group_size=self.group_size, + ) + + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + 
smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + output_format = 1, + ) + + pt_output_3 = ixfops.moe_w4a8_group_gemv( + input=pt_output_2, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_gpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + format=self.gemm_format, + group_size=self.group_size, + ) + else: + expert_sizes_cpu = expert_sizes_gpu.cpu() + # expand + reorder + quant + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + ) + + # w4a8 group gemm 1 + pt_output_1 = ixfops.moe_w4a8_group_gemm( + input=i8_hidden_states, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w13_i8_weight_scale, + w_i8zeros=layer.w13_i8_weight_zero, + dst_to_src=None, + format=self.gemm_format, + group_size=self.group_size, + version=self.version + ) + + # act + quant + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) + + # w4a8 group gemm 2 + reorder + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + ixfops.moe_w4a8_group_gemm( + input=pt_output_2, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + format=self.gemm_format, + group_size=self.group_size, + version=self.version, + output=pt_output_3, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weight, + extra_residual=extra_residual, + scaling_factor=routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w4a8_group_gemm( + input=pt_output_2, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + format=self.gemm_format, + group_size=self.group_size, + version=self.version + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weight, + extra_residual=extra_residual, + scaling_factor=routed_scaling_factor + ) + return final_hidden_states + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. 
+ config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + assert params_dtype == torch.float16, ( + "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 + ) + + intermediate_size_full = extra_weight_attrs.pop( + "intermediate_size_full") + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales + load_full_w2 = self.actorder and self.group_size != -1 + w2_scales_size = (intermediate_size_full + if load_full_w2 else intermediate_size_per_partition) + + self.is_k_full = (not self.actorder) or ( + intermediate_size_per_partition == intermediate_size_full) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + 
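+        # (Roughly: g_idx maps each input channel to its quantization group.
+        # For act-order ("group") checkpoints these indices are loaded from the
+        # checkpoint and argsorted in process_weights_after_loading so the
+        # Marlin kernel can gather channels into group order; otherwise they
+        # are replaced by empty tensors.)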
layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int, num_bits: int): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, + scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, + size_n: int, group_size: int, + num_bits: int): + num_experts = s.shape[0] + output = torch.empty((num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, + group_size, num_bits) + return output + + size_k2 = layer.w2_weight_packed.shape[2] + size_k13 = layer.w13_weight_packed.shape[2] + + num_experts = layer.w13_weight_g_idx.shape[0] + device = layer.w13_weight_g_idx.device + + # when running models with grouped act order, + # resort to g_idx values provided in checkpoint + if self.actorder == "group": + w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) + + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_weight_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort( + layer.w2_weight_g_idx[e]).to(torch.int32) + w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][ + w2_g_idx_sort_indices[e]] + + replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) + replace_parameter(layer, 
"w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + + else: + layer.w13_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, + size_k13, + layer.w13_weight_scale.shape[2], + self.group_size, + self.num_bits, + ) + replace_tensor("w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, + layer.w2_weight_scale.shape[1] * + (self.group_size if self.group_size != -1 else self.packed_factor), + size_k2, + self.group_size, + self.num_bits, + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + extra_residual: torch.Tensor = None, + routed_scaling_factor: float = 1.0, + ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." 
+ if expert_map is not None: + raise NotImplementedError( + "Expert Parallelism is not supported for " + "fused Marlin MoE method.") + if apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for " + "fused Marlin MoE method.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_weight_g_idx, + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits, + is_k_full=self.is_k_full) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000..b26c74f --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, + CompressedTensorsW4A16Sparse24) +from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 +from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 +from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 +from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, + CompressedTensorsWNA16) + +from .compressed_tensors_24 import CompressedTensors24 # isort: skip + +__all__ = [ + "CompressedTensorsScheme", "CompressedTensorsWNA16", + "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", + "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", + "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", + "CompressedTensors24" +] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py new file mode 100644 index 0000000..ec805c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -0,0 +1,357 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from compressed_tensors import CompressionFormat, ModelCompressor +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) +from compressed_tensors.utils import combine_shards + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise, sparse_cutlass_supported) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +__all__ = ["CompressedTensors24"] + + +class 
CompressedTensors24(CompressedTensorsScheme): + + def __init__( + self, + quantized: bool = False, + weight_quant: Optional[QuantizationArgs] = None, + input_quant: Optional[QuantizationArgs] = None, + model_compression_config: Optional[Dict[str, Any]] = None, + ): + self.quantized = quantized + self.weight_quant = weight_quant + self.input_quant = input_quant + self.model_compressor = ( + ModelCompressor.from_compression_config(model_compression_config) + if model_compression_config is not None else None) + self.do_sparse_decompress = ( + self.model_compressor is not None + and self.model_compressor.sparsity_config.format + == CompressionFormat.sparse_24_bitmask.value) + + @classmethod + def get_min_capability(cls) -> int: + # Only cutlass 3.x kernels are implemented so far + return 90 + + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + if not sparse_cutlass_supported(): + raise ValueError( + "Sparse CUTLASS not supported. vLLM must be built with " + "CUDA 12.2 or later to use this feature") + + layer.logical_widths = output_partition_sizes + layer.input_size = input_size + layer.input_size_per_partition = input_size_per_partition + self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) + + # parameter to store uncompressed weight + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=self.weights_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + if self.do_sparse_decompress: + assert all(partition_size % 8 == 0 + for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for " + "2:4 sparse compressed models" + + shape = BasevLLMParameter( + data=torch.empty(2, 1, dtype=torch.int64), + weight_loader=weight_loader, + ) + compressed_weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + bitmask = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 8, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("shape", shape) + layer.register_parameter("compressed", compressed_weight) + layer.register_parameter("bitmask", bitmask) + + # Check if quantized, not just 2:4 Sparse + if self.quantized: + if (self.weight_quant and self.weight_quant.strategy + == QuantizationStrategy.CHANNEL.value): + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + else: + assert (self.weight_quant and self.weight_quant.strategy + == QuantizationStrategy.TENSOR.value) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + # input quant will be non-none + if self.input_quant and not self.input_quant.dynamic: + # register input quant scale + assert (self.input_quant.strategy == + QuantizationStrategy.TENSOR.value) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), + weight_loader=weight_loader, + ) + + 
layer.register_parameter("input_scale", input_scale) + + else: + # for sparse-only, pass in 1 for weight/input scales + weight_scale = torch.nn.Parameter(data=torch.ones( + 1, dtype=torch.float32), + requires_grad=False) + input_scale = torch.nn.Parameter(data=torch.ones( + 1, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("input_scale", input_scale) + layer.register_parameter("weight_scale", weight_scale) + + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """ + Compress weights after loading. Store compressed weight and meta + tensor + + :post-condition: layer.w_compressed and layer.meta are + set to the compressed weight and meta tensor in the + format expected by the Cutlass kernels + :param layer: The layer with the weights to be processed + + """ + if self.do_sparse_decompress: + layer.weight.data = self._decompress_bitmask_compressed_weight( + compressed=layer.compressed, + bitmask=layer.bitmask, + layer=layer, + ) + + # compressed and bitmask tensors + # are no longer needed after decompression + del layer.compressed + del layer.bitmask + + # torch.compile workaround + if hasattr(layer, "input_scale"): + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + + if self.weight_quant: + if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value: + layer.weight_scale = torch.nn.Parameter( + convert_to_channelwise( + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ), + requires_grad=False, + ) + else: + # torch.compile workaround + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False) + + # Set all negative zero values to 0 prior to compression + if (layer.weight.dtype.is_floating_point + and layer.weight.dtype.itemsize >= 2): + layer.weight.data[layer.weight.data == -0.0] = 0.0 + + w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data) + layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False) + layer.meta = torch.nn.Parameter(meta, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Returns the output tensor for the layer with 2:4 + sparse compressed weights, given the input tensor + and bias + + :param layer: The layer with 2:4 sparse compressed + weights to be used for the computation + :param x: The input tensor to the layer + :param bias: The bias to be added to the output tensor + :return: The output tensor of the layer + """ + if self.quantized: + scale = None + if hasattr(layer, "input_scale"): + scale = layer.input_scale + + if self.weights_dtype == torch.int8: + ops_output = ops.scaled_int8_quant(x, scale=scale) + q_input = ops_output[0] + input_scale = ops_output[1] + else: + assert self.weights_dtype == torch.float8_e4m3fn + if scale is not None: + q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale) + else: + q_input, input_scale = ops.scaled_fp8_quant( + x, use_per_token_if_dynamic=True) + + else: + # Not quantized, nothing to do with the input_scales, use as is + input_scale = layer.input_scale + q_input = x + + out = ops.cutlass_scaled_sparse_mm( + a=q_input, + bt_nzs=layer.weight, + bt_meta=layer.meta, + scale_a=input_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias, + ) + + assert out.is_contiguous() + return out + + def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: + if not 
self.quantized: + return params_dtype + + assert self.weight_quant is not None + assert self.input_quant is not None + + is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8 + + if not is_8_bits: + raise ValueError("Cutlass only supports 8-bit quantization") + + if (self.weight_quant.type == QuantizationType.FLOAT + and self.input_quant.type == QuantizationType.FLOAT): + return torch.float8_e4m3fn + + if (self.weight_quant.type == QuantizationType.INT + and self.input_quant.type == QuantizationType.INT): + return torch.int8 + + raise ValueError("Quantization type not supported by Cutlass") + + def _decompress_bitmask_compressed_weight( + self, + compressed: torch.Tensor, + bitmask: torch.Tensor, + layer: torch.nn.Module, + ) -> torch.Tensor: + """ + Decompress a compressed 2:4 sparse weight tensor using the bitmask and + return the result. + + This function also supports sharded decompression. + + :param compressed: The 2:4 sparse weight tensor compressed using the + sparse-24-bitmask compressor. This is different from + `cutlass_sparse_compress` which uses a different scheme (2 bits for + every nonzero element that represent the coordinate within the block + of 4). The bitmask compression here uses a bitmask to indicate the + positions of non-zero elements. + :param bitmask: The 2:4 bitmask associated with the compressed weights, + representing the positions of non-zero elements in the compressed + tensor. + :param layer: The layer whose weights need to be processed after + loading. + :return: The decompressed 2:4 sparse weight tensor. + """ + + sparsity_compressor = self.model_compressor.sparsity_compressor + + def _process_split( + bitmask_compressed_weight: torch.Tensor, + shape, + bitmask: torch.Tensor, + ) -> torch.Tensor: + weight_data = dict( + compressed=bitmask_compressed_weight, + shape=shape, + bitmask=bitmask, + ) + return sparsity_compressor.decompress_weight(weight_data) + + split_weights: List[torch.Tensor] = [] + split_bitmask: List[torch.Tensor] = [] + split_shape: List[Tuple[int, int]] = [] + + if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): + split_weights = torch.split(compressed, layer.logical_widths) + split_bitmask = torch.split(bitmask, layer.logical_widths) + split_shape = [(out, layer.input_size_per_partition) + for out in layer.logical_widths] + + if split_weights: + decompressed_shards = [ + _process_split(compressed_weight, shape, bitmask) + for compressed_weight, shape, bitmask in zip( + split_weights, split_shape, split_bitmask) + ] + decompressed = combine_shards(decompressed_shards) + else: + decompressed = sparsity_compressor.decompress_weight( + dict( + compressed=compressed, + shape=( + layer.logical_widths[0], + layer.input_size_per_partition, + ), + bitmask=bitmask, + )) + return decompressed diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000..daa25d2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by 
CompressedTensors. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py new file mode 100644 index 0000000..535ea6b --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +__all__ = ["CompressedTensorsW4A16Sparse24"] +W4A16SPARSE24_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, +} +W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None): + self.strategy = strategy + self.group_size = group_size + self.tile_size = 16 + + if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. 
" + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + + self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] + + if self.strategy == "group" and self.group_size is None: + raise ValueError( + "group_size must be given when using strategy group") + + @classmethod + def get_min_capability(cls) -> int: + # ampere + up + return 80 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile to be torch.nn.Parameter + layer.weight_packed = Parameter(layer.weight_packed.data, + requires_grad=False) + layer.scale_packed = Parameter(layer.scale_packed.data, + requires_grad=False) + layer.meta = Parameter(layer.meta.data, requires_grad=False) + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + assert params_dtype == torch.float16, ( + "float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501 + ) + + pack_factor = 32 // self.quant_type.size_bits + output_size_per_partition = sum(output_partition_sizes) + + qweight = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // self.tile_size // 2, + output_size_per_partition * self.tile_size // pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=pack_factor, + marlin_tile_size=self.tile_size, + weight_loader=weight_loader) + + input_groups = (1 if self.group_size is None else + input_size_per_partition // self.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if self.group_size is not None: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + else: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", qweight) + layer.register_parameter("weight_shape", weight_shape) + layer.register_parameter("scale_packed", scales) + layer.register_parameter("meta", meta) + + max_workspace_size = ( + output_size_per_partition // + GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL + + workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), + requires_grad=False) + layer.workspace = workspace + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + + qweight = layer.weight_packed + meta = layer.meta + scales = layer.scale_packed + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, self.quant_type, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py new file mode 100644 index 0000000..5c82619 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +__all__ = ["CompressedTensorsW8A16Fp8"] + +SUPPORTED_STRATEGIES = [ + QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR +] + + +class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + # W8A8-Fp8 kernels support only per-tensor and per-channel cases. + # So if we have a fused module (QKV, MLP) with per tensor scales, + # we expand each scale to its shard's channels. + def process_weights_after_loading(self, layer) -> None: + if self.strategy == QuantizationStrategy.TENSOR: + ws_channelwise = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) + layer.weight_scale = torch.nn.Parameter(ws_channelwise, + requires_grad=False) + else: + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + + # Weights must be transposed for marlin + layer.weight = torch.nn.Parameter(layer.weight.t(), + requires_grad=False) + + if self.is_static_input_scheme: + # required by torch.compile to be torch.nn.Parameter + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + prepare_fp8_layer_for_marlin(layer, strategy="channel") + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + elif self.strategy == QuantizationStrategy.TENSOR: + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + else: + raise ValueError( + f"Unsupported weight 
strategy={self.strategy}, " + f"supported strategies are {SUPPORTED_STRATEGIES}") + + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE (to deal with converted checkpoints) + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_marlin_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py new file mode 100644 index 0000000..e99a452 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional + +import torch +from compressed_tensors.quantization import QuantizationStrategy +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +__all__ = ["CompressedTensorsW8A8Fp8"] + + +class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.out_dtype = torch.get_default_dtype() + self.is_static_input_scheme = is_static_input_scheme + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer) -> None: + # If per tensor, when we have a fused module (e.g. QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.strategy == QuantizationStrategy.TENSOR: + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + if current_platform.is_fp8_fnuz(): + input_scale = getattr(layer, 'input_scale', None) + + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. 
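+        # ("Lined up" meaning there is already exactly one scale per output
+        #  channel, so unlike the per-tensor case above no requantization is
+        #  needed before handing the scales to the kernel.)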
+ elif self.strategy == QuantizationStrategy.CHANNEL: + weight = layer.weight + + if current_platform.is_fp8_fnuz(): + input_scale = getattr(layer, 'input_scale', None) + + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + + layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + else: + raise ValueError(f"Unknown quantization strategy {self.strategy}") + + # INPUT SCALE + if self.is_static_input_scheme and hasattr(layer, 'input_scale'): + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py new file mode 100644 index 0000000..898c8e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional, Set + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from 
vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + + +class CompressedTensorsW8A8Int8(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, strategy: str, is_static_input_scheme: bool, + input_symmetric: bool): + import vllm.envs as env + if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR": + self.strategy = QuantizationStrategy.TENSOR + elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL": + self.strategy = QuantizationStrategy.CHANNEL + else: + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + self.input_symmetric = input_symmetric + + @classmethod + def get_min_capability(cls) -> int: + # turing and up + return 75 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + layer.logical_widths = output_partition_sizes + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), + is_static_input_scheme=self.is_static_input_scheme, + input_symmetric=self.input_symmetric) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW8A8Int8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + if not self.input_symmetric: + # Note: compressed-tensors stores the zp using the same dtype + # as the weights + # AZP loaded as int8 but used as int32 + input_zero_point = BasevLLMParameter( + data=torch.empty(1, dtype=torch.int8), + weight_loader=weight_loader) + layer.register_parameter("input_zero_point", input_zero_point) + + self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. 
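+    # (The selected ScaledMM kernel backend owns the final storage layout of
+    # the int8 weight and scales, so the repacking is delegated to its own
+    # process_weights_after_loading below.)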
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py new file mode 100644 index 0000000..38df09f --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional, Set + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsWNA16"] +WNA16_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128 +} +WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size == -1 and self.strategy != "channel": + raise ValueError("Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise.") + + if num_bits not in WNA16_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. 
" + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsWNA16", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ) + } + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. 
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py new file mode 100644 index 0000000..b69c5e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Type + +import torch +import triton +import triton.language as tl + + +def is_weak_contiguous(x: torch.Tensor): + strides = x.stride() + sizes = x.shape + is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0])) + is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1])) + return is_transpose or is_not_transpose + + +@triton.jit +def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, + M, N, K, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr): + pid = tl.program_id(axis=0) + + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = ACCUMULATOR_DTYPE + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=accumulator_dtype) + + # NOTE: Some tensor inputs are so large, they will cause int32 overflow + # so it is necessary to use tl.int64 for all the offsets, else SEGV will + # eventually occur. + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + masks_am = offsets_am < M + + offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + masks_bn = offsets_bn < N + + offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + offsets_a = (stride_am * offsets_am[:, None] + + stride_ak * offsets_k[None, :]) + offsets_b = (stride_bk * offsets_k[:, None] + + stride_bn * offsets_bn[None, :]) + + # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create + # appropriate offsets and masks for each case. Same goes for + # BLOCK_SIZE_SCALE_B. + offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) + + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M) + masks_scale_am = offsets_scale_am < M + + offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) + + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N) + masks_scale_bn = offsets_scale_bn < N + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + scale_a_ptrs = scale_a_ptr + offsets_scale_am + scale_b_ptrs = scale_b_ptr + offsets_scale_bn + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Apply scale at end. 
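+    # (Applying both scales once per output tile after the K loop keeps the
+    # inner loop a plain dot product accumulated in int32/fp32.)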
+ masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None] + scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a) + # Need to broadcast to the appropriate size, if scale_a is already + # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes + # for scale_b below. + scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1)) + accumulator = scale_a * accumulator.to(tl.float32) + + masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :] + scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b) + scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1)) + accumulator = scale_b.T * accumulator.to(tl.float32) + + # Convert to output format. + c = accumulator.to(c_ptr.type.element_ty) + + # Add bias, it's already in output format, so add it after conversion. + if bias_ptr: + offsets_bias = offsets_bn + bias_ptrs = bias_ptr + offsets_bias + bias_mask = offsets_bias < N + bias = tl.load(bias_ptrs, bias_mask) + c += bias + + # Save output + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + tl.store(c_ptrs, c, mask=c_mask) + + +# input - [M, K] +# weight - [K, N] +def triton_scaled_mm(input: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: Type[torch.dtype], + bias: Optional[torch.Tensor] = None, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, + use_heuristic=True) -> torch.Tensor: + M, K = input.shape + N = weight.shape[1] + + assert N > 0 and K > 0 and M > 0 + assert weight.shape[0] == K + assert input.dtype == weight.dtype + + scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a + scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b + + assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() + assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size( + [M, 1]) + assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size( + [N, 1]) + assert out_dtype.is_floating_point + assert bias is None or bias.is_floating_point() + assert is_weak_contiguous(input) + assert is_weak_contiguous(weight) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( + N, META['BLOCK_SIZE_N']), ) + + result = torch.empty((M, N), dtype=out_dtype, device=input.device) + + has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 + + if use_heuristic: + is_small_N = N < 8192 + next_power_of_2_M = max(32, triton.next_power_of_2(M)) + if next_power_of_2_M <= 32: + tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256) + elif next_power_of_2_M <= 64: + tile_shape = (64, 64, 256) + elif next_power_of_2_M <= 128: + tile_shape = (64, 128, 128) + else: + tile_shape = (128, 128, 128) + + block_size_m, block_size_n, block_size_k = tile_shape + + block_size_sa = 1 if has_scalar(scale_a) else block_size_m + block_size_sb = 1 if has_scalar(scale_b) else block_size_n + + accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32 + + # A = input, B = weight, C = result + # A = M x K, B = K x N, C = M x N + scaled_mm_kernel[grid](input, + weight, + scale_a, + scale_b, + result, + bias, + M, + N, + K, + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + result.stride(0), + 
result.stride(1), + accumulator_dtype, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_SCALE_A=block_size_sa, + BLOCK_SIZE_SCALE_B=block_size_sb) + + return result.to(out_dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py new file mode 100644 index 0000000..85ae1d5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from types import MappingProxyType +from typing import Iterable, List, Mapping, Optional + +from compressed_tensors import CompressionFormat +from torch.nn import Module + + +def is_activation_quantization_format(format: str) -> bool: + _ACTIVATION_QUANTIZATION_FORMATS = [ + CompressionFormat.naive_quantized.value, + CompressionFormat.int_quantized.value, + CompressionFormat.float_quantized.value, + ] + return format in _ACTIVATION_QUANTIZATION_FORMATS + + +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str] = tuple(), + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping and layer_name not in ignore: + shard_proj_names = fused_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError(f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, + targets=ignore) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, + targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def find_matched_target( + layer_name: Optional[str], + module: Module, + targets: Iterable[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> str: + """ + Helper function to look up which "target" in the compressed-tensors + config that a layer corresponds to. + + Recall that a compressed-tensors configs has a concept of + config_groups, where each layer can be quantized with with a different + scheme. 
+ + targets in each config_group will be a list of either layer names + (or regexes corresponding to layer names) or names of torch Modules. + + First, we try to match the layer_name with a target + Second, we try to match the module's name with a target + Third, we try to map the layer_name to a list of fused module names. + *All* component module names must match in order for a match to be + successful. A successful match returns the first component target + + :param layer_name: layer name + :param module: torch.nn.Module + :param targets: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + :param fused_strategy: either "all" or "any". If using "all", fused + layers match if "all" of its components match + """ + + if layer_name is None: + layer_name = "" + + matched_target = ( + _find_first_match(layer_name, targets) + or _find_first_match(module.__class__.__name__, targets, True) + or _match_fused_layer(layer_name, targets, fused_mapping)) + + if matched_target is None: + raise ValueError( + f"Unable to find matching target for {layer_name} in the " + "compressed-tensors config.") + + return matched_target + + +def _find_first_match(value: str, + targets: Iterable[str], + check_contains: bool = False) -> Optional[str]: + """ + Returns first element of target that matches value either + exactly or as a regex after 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + + :param value: string to compare the list of targets against + :param targets: list of targets to match the layer against + :param check_contains: whether or not to do a substring match + """ + + for target in targets: + if _is_equal_or_regex_match(value, + target, + check_contains=check_contains): + return target + return None + + +def _is_equal_or_regex_match(value: str, + target: str, + check_contains: bool = False) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False + + +def _match_fused_layer( + layer_name: str, target_layers: Iterable[str], + fused_mapping: Mapping[str, List[str]]) -> Optional[str]: + """ + Match a fused layer name to its corresponding individual layer in + target_layers. 
Returns first value in fused_mapping which matches targets + + Implements an "all" matching strategy where a fused layer matches iff + "all" of its components match + + :param layer_name: layer name + :param target_layers: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + + Examples: + layer_name = "model.layers.0.self_attn.qkv_proj" + target_layers = ["model.layers.0.self_attn.q_proj", + "model.layers.0.self_attn.k_proj", + "model.layers.0.self_attn.v_proj"] + """ + # find layer_name in mapping + fused = next((key for key in fused_mapping if layer_name.endswith(key)), + None) + if fused is None: + return None + + # expand path of unfused components + unfused_paths = [ + layer_name.replace(fused, unfused) for unfused in fused_mapping[fused] + ] + + # for each unfused component, find a match in targets + unfused_matches: List[Optional[str]] = [] + for unfused in unfused_paths: + for target in target_layers: + if _is_equal_or_regex_match(unfused, target): + unfused_matches.append(target) + break + else: + unfused_matches.append(None) + + return unfused_matches[0] if all(unfused_matches) else None diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 0000000..67934d3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. 
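+            Note that the constructor below defaults to group_size=512.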
+ """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "DeepSpeedFP" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. 
+ """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py new file mode 100644 index 0000000..be19b80 --- /dev/null +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional + +import torch + +from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs + + +class ExpertsInt8Config(QuantizationConfig): + """Config class for Int8 experts quantization.""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> str: + return "experts_int8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + elif isinstance(layer, FusedMoE): + return ExpertsInt8MoEMethod(self) + return None + + 
+class ExpertsInt8MoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: ExpertsInt8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + int8_dtype = torch.int8 + + assert 'weight_loader' in extra_weight_attrs + weight_loader = extra_weight_attrs['weight_loader'] + wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader( + layer, weight_loader) + extra_weight_attrs['weight_loader'] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=int8_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=int8_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_scale = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_int8_w8a16=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale) + + @staticmethod + def quantizing_weight_loader(layer, weight_loader): + + def quantize_and_call_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: int, + expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + shard_size = layer.intermediate_size_per_partition + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + device = get_tp_group().device + loaded_weight = loaded_weight.to(device) + # w1, gate_proj case: Load into first shard of w13. 
+ if shard_id == "w1": + scales = quantize_in_place_and_get_scales( + loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, + 0]) + # w3, up_proj case: Load into second shard of w13. + elif shard_id == "w3": + scales = quantize_in_place_and_get_scales( + loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, shard_size:2 * + shard_size].copy_(scales[:, 0]) + # w2, down_proj case: Load into only shard of w2. + elif shard_id == "w2": + scales = quantize_in_place_and_get_scales(loaded_weight[:, + shard]) + layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) + else: + raise ValueError( + f"Shard id must be in [0,1,2] but got {shard_id}") + weight_loader(param, loaded_weight, weight_name, shard_id, + expert_id) + + return quantize_and_call_weight_loader + + +def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor: + vmax = torch.iinfo(torch.int8).max + scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax) + + weight.div_(scales) + weight.round_() + weight.clamp_(-vmax, vmax) + + return scales diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py new file mode 100644 index 0000000..7dddc40 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class FBGEMMFp8Config(QuantizationConfig): + """Config class for FBGEMM Fp8.""" + + def __init__(self, ignore_list: List[str], input_scale_ub: float): + super().__init__() + self.ignore_list = ignore_list if ignore_list else [] + self.input_scale_ub = input_scale_ub + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = not current_platform.has_device_capability(89) + self.fp8_linear = Fp8LinearOp() + + @classmethod + def get_name(cls) -> str: + return "fbgemm_fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": + ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) + input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) + return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if 
isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignore_list): + return UnquantizedLinearMethod() + return FBGEMMFp8LinearMethod(self) + return None + + +class FBGEMMFp8LinearMethod(LinearMethodBase): + + def __init__(self, quant_config: FBGEMMFp8Config): + self.quant_config = quant_config + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.out_dtype = torch.get_default_dtype() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + maybe_create_device_identity() + weight_loader = extra_weight_attrs.get("weight_loader") + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = ChannelQuantScaleParameter(data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE UPPER BOUND + input_scale_ub = torch.nn.Parameter(torch.tensor( + (self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False) + layer.input_scale_ub = input_scale_ub + + def process_weights_after_loading(self, layer: Module) -> None: + # required by torch.compile + layer.weight_scale = Parameter(layer.weight_scale.data, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + weight = layer.weight + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=None) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + if self.quant_config.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. 
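+            # (The static activation upper bound is therefore unused and is
+            # dropped just below.)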
+ del layer.input_scale_ub + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.quant_config.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=None, + input_scale_ub=layer.input_scale_ub, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py new file mode 100644 index 0000000..4435644 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -0,0 +1,823 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib.util +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn import Module +from torch.nn.parameter import Parameter + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, all_close_1d, convert_to_channelwise, + cutlass_block_fp8_supported, cutlass_fp8_supported, + maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, requantize_with_max_scale) +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + + +def _is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + weight_block_size: Optional[List[int]] = None, + ) -> None: + super().__init__() + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. 
Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + "The block-wise quantization only supports fp8-serialized " + "checkpoint for now.") + if len(weight_block_size) != 2: + raise ValueError( + "The quantization block size of weight must have 2 " + f"dimensions, but got {len(weight_block_size)} dimensions") + if activation_scheme != "dynamic": + raise ValueError("The block-wise quantization only supports " + "dynamic activation scheme for now, but got " + f"{activation_scheme} activation scheme.") + self.weight_block_size = weight_block_size + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], + None) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + return None + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. 
Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.out_dtype = torch.get_default_dtype() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + + self.block_quant = self.quant_config.weight_block_size is not None + if self.block_quant: + # Marlin doesn't support block-wise fp8 + self.use_marlin = False + + self.fp8_linear = Fp8LinearOp( + # Default to using per_token quantization if cutlass is supported + use_per_token_if_dynamic=cutlass_fp8_supported()) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + if self.block_quant: + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_config.weight_block_size is not None + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if (tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + # Required by column parallel or enabling merged weights + if (tp_size > 1 and output_size // output_size_per_partition + == tp_size) or len(output_partition_sizes) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. 
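+        # (For BF16/FP16 checkpoints the weight above stays in the original
+        # dtype and is quantized to fp8 in process_weights_after_loading.)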
+ if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + if not self.block_quant: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) + else: + assert self.quant_config.activation_scheme == "dynamic" + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + # The weight_scale_inv name is intentional for deepseekv3 + layer.register_parameter("weight_scale_inv", scale) + + # INPUT ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + else: + layer.register_parameter("input_scale", None) + + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + def process_weights_after_loading(self, layer: Module) -> None: + # TODO(rob): refactor block quant into separate class. + if self.block_quant: + assert self.quant_config.activation_scheme == "dynamic" + if current_platform.is_fp8_fnuz(): + weight, weight_scale_inv, _ = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=layer.weight, + weight_scale=layer.weight_scale_inv) + else: + weight = layer.weight.data + weight_scale_inv = layer.weight_scale_inv.data + + weight = self._maybe_pad_weight(weight) + + # Torch.compile cannot use Parameter subclasses. + layer.weight = Parameter(weight, requires_grad=False) + layer.weight_scale_inv = Parameter(weight_scale_inv, + requires_grad=False) + return + + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. + if self.use_marlin: + assert weight_scale.numel() == 1 + weight_scale = convert_to_channelwise( + weight_scale.expand(len(layer.logical_widths)), + layer.logical_widths) + + # Update the layer with the new values. + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module + else: + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. 
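+            # (convert_to_channelwise repeats each shard's per-tensor scale
+            # across that shard's output channels.)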
+ if self.use_marlin: + weight = layer.weight + weight_scale = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) + + # If using w8a8, torch._scaled_mm needs per tensor, so + # requantize the logical shards as a single weight. + else: + # Dequant -> Quant with max scale so we can run per tensor. + weight = layer.weight + weight_scale = layer.weight_scale + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=layer.logical_widths, + ) + + weight = self._maybe_pad_weight(weight) + # Update layer with new values. + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + if self.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + + if self.block_quant: + assert self.quant_config.weight_block_size is not None + return torch.ops.vllm.apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + ) + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias) + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None + + # Check for DeepGemm support. 
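+        # DeepGemm is only used when VLLM_USE_DEEP_GEMM is set, the deep_gemm
+        # package imports cleanly, and the device reports compute capability
+        # 9.0 (Hopper) or newer.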
+ self.allow_deep_gemm = False + if envs.VLLM_USE_DEEP_GEMM: + if not has_deep_gemm: + logger.warning_once("Failed to import DeepGemm kernels.") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(90)): + logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") + self.allow_deep_gemm = True + else: + logger.warning_once( + "DeepGemm not supported on the current platform.") + + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.quant_config.weight_block_size is not None + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}.") + if (tp_size > 1 + and intermediate_size_per_partition % block_k != 0): + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if not self.block_quant: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
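+            # e.g. w13_weight_scale has shape (num_experts, 2); in
+            # process_weights_after_loading() the max over the two shards
+            # leaves one scale per expert for the fused w13 weight.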
+ w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + else: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // + block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK. + value} if self.block_quant else + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + # Lazy import to avoid importing triton too early. + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + expand_weights, is_rocm_aiter_block_scaled_moe_enabled, + is_rocm_aiter_moe_enabled, shuffle_weights) + + # TODO (rob): refactor block quant into separate class. + if self.block_quant: + assert self.quant_config.activation_scheme == "dynamic" + if current_platform.is_fp8_fnuz(): + w13_weight, w13_weight_scale_inv, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale_inv, + layer.w13_input_scale) + w2_weight, w2_weight_scale_inv, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale_inv, + layer.w2_input_scale) + else: + w13_weight = layer.w13_weight.data + w13_weight_scale_inv = layer.w13_weight_scale_inv.data + w2_weight = layer.w2_weight + w2_weight_scale_inv = layer.w2_weight_scale_inv + + # torch.compile() cannot use Parameter subclasses. 
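+            # Re-wrap as plain torch.nn.Parameter: the vLLM parameter
+            # subclasses are only needed while weights are being loaded.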
+ layer.w13_weight = Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv, + requires_grad=False) + layer.w2_weight = Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, + requires_grad=False) + if is_rocm_aiter_block_scaled_moe_enabled(): + # reshaping weights is required for aiter moe kernel. + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + # DeepGemm scales need to be transposed and aligned. We try to do + # it ahead of time for performance reasons. + if self.allow_deep_gemm: + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = \ + dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = \ + dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + + return + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + fp8_dtype = current_platform.fp8_dtype() + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + layer.local_num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) + for expert in range(layer.local_num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + if is_rocm_aiter_moe_enabled(): + # reshaping weights is required for aiter moe kernel. + w13_scales, w2_scales = expand_weights( + layer.w13_weight_scale.data, + layer.w2_weight_scale.data, + expansion_dims=[ + layer.w13_weight.shape[1], layer.w2_weight.shape[1] + ]) + layer.w13_weight_scale = torch.nn.Parameter( + w13_scales.contiguous(), requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_scales.contiguous(), requires_grad=False) + + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight, layer.w2_weight) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. 
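+        # Taking the max is the conservative choice: values quantized with a
+        # larger scale stay in range, at the cost of slightly coarser
+        # resolution for experts whose own scale was smaller.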
+ if self.quant_config.activation_scheme == "static": + if (layer.w13_input_scale is None + or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + if current_platform.is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + if is_rocm_aiter_moe_enabled(): + # reshaping weights is required for aiter moe kernel. 
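+                # Note: expand_weights is assumed here to broadcast the
+                # per-expert scales along expansion_dims so they line up with
+                # the row layout produced by shuffle_weights below.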
+ expansion_dims = [ + layer.w13_weight.shape[1], layer.w2_weight.shape[1] + ] + max_w13_scales, w2_scales = expand_weights( + max_w13_scales, + layer.w2_weight_scale.data, + expansion_dims=expansion_dims) + layer.w2_weight_scale = torch.nn.Parameter( + w2_scales.contiguous(), requires_grad=False) + + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight, layer.w2_weight) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + ) + + +class Fp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. 
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        super().__init__(quant_config)
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
new file mode 100644
index 0000000..3d32d69
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -0,0 +1,685 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional, Generator, Tuple, Callable
+
+import gguf
+import torch
+import torch.nn.functional as F
+from gguf import GGMLQuantizationType as WeightType
+from torch.nn.parameter import Parameter, UninitializedParameter
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.quantization.utils.gguf_utils import get_awq_format, dequant_gguf
+import gguf
+import ctypes
+import ixformer.inference.functions as IXF
+try:
+    import ktransformers.cpuinfer_ext as cpuinfer_ext
+    from ktransformers.cpuinfer_ext.moe import MOEConfig, MOE
+except Exception:
+    try:
+        print("ktransformers may not be installed; trying to import cpuinfer_ext from vllm.")
+        import vllm.cpuinfer_ext as cpuinfer_ext
+        from vllm.cpuinfer_ext.moe import MOEConfig, MOE
+    except Exception:
+        print("ktransformers CPU inference component cpuinfer_ext.so not found; DeepSeek models will not work.")
+        cpuinfer_ext = None
+        MOEConfig = None
+        MOE = None
+
+logger = init_logger(__name__)
+
+
+class GGUFConfig(QuantizationConfig):
+    """Config class for GGUF."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __repr__(self) -> str:
+        return "GGUFConfig()"
+
+    def get_name(self) -> str:
+        return "gguf"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half, torch.bfloat16, torch.float32]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []  # no extra configs.
+ + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return GGUFLinearMethod(self) + elif isinstance(layer, VocabParallelEmbedding): + return GGUFEmbeddingMethod(self) + elif isinstance(layer, FusedMoE): + return GGUFMoEMethod(self, prefix) + return None + + +UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} +STANDARD_QUANT_TYPES = { + WeightType.Q4_0, + WeightType.Q4_1, + WeightType.Q5_0, + WeightType.Q5_1, + WeightType.Q8_0, + WeightType.Q8_1, +} +KQUANT_TYPES = { + WeightType.Q2_K, + WeightType.Q3_K, + WeightType.Q4_K, + WeightType.Q5_K, + WeightType.Q6_K, +} +IMATRIX_QUANT_TYPES = { + WeightType.IQ1_M, + WeightType.IQ1_S, + WeightType.IQ2_XXS, + WeightType.IQ2_XS, + WeightType.IQ2_S, + WeightType.IQ3_XXS, + WeightType.IQ3_S, + WeightType.IQ4_XS, + WeightType.IQ4_NL, +} +# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add +# MMQ kernel for I-Matrix quantization. +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES + + +def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, + qweight_type: int) -> torch.Tensor: + # HACK: when doing chunked prefill we don't generate output tokens + # so input to logits generator is empty which causes invalid parameter + if x.shape[0] == 0: + return torch.empty(x.shape[0], + qweight.shape[0], + dtype=x.dtype, + device=x.device) + # there is no need to call any kernel for fp16/bf16 + if qweight_type in UNQUANTIZED_TYPES: + return x @ qweight.T + # enable MMVQ in contiguous batching with batch_size=1 + if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES: + y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + # Use MMQ Kernel if it's available (standard + k-quants) + elif qweight_type in MMQ_QUANT_TYPES: + y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + # If there is no available MMQ kernel, fallback to dequantize + elif qweight_type in DEQUANT_TYPES: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) + y = x @ weight.T + else: + # Raise an error if the quantization type is not supported. + # Might be useful if llama.cpp adds a new quantization type. + # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. 
+ qweight_type = WeightType(qweight_type) + raise NotImplementedError( + f"Unsupported GGUF quantization type: {qweight_type}") + return y + + +def _fused_moe_gguf( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + qweight_type: int, + qweight_type2: int, + act, +) -> torch.Tensor: + # lazy import to avoid triggering triton import in CPU backend + from vllm.model_executor.layers.fused_moe.fused_moe import ( + moe_align_block_size) + + out_hidden_states = torch.empty_like(x) + if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES: + num_tokens, _ = x.shape + E, N, _ = w1.shape + top_k = topk_ids.shape[1] + BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type) + + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, BLOCK_SIZE, E) + out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids, + num_tokens_post_padded, qweight_type, N, top_k, + num_tokens) + out = act(out) + out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids, + num_tokens_post_padded, qweight_type2, + w2.shape[1], 1, num_tokens * top_k) + out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( + topk_weights.view(num_tokens, top_k, 1)) + ops.moe_sum(out, out_hidden_states) + else: + logger.warning_once("There is no support for fast MoE kernel " + "for current quantization method. " + "Falling back to slow implementation. ") + for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)): + inp = x[tok].reshape((1, ) + x.shape[1:]) + current_hidden_state = None + for ww, ii in zip(w, idx): + expert_up = w1[ii] + + out = _fuse_mul_mat(inp, expert_up, qweight_type) + out = act(out) + + expert_down = w2[ii] + current_state = _fuse_mul_mat(out, expert_down, + qweight_type2).mul_(ww) + if current_hidden_state is None: + current_hidden_state = current_state + else: + current_hidden_state.add_(current_state) + out_hidden_states[tok] = current_hidden_state + return out_hidden_states + + +class GGUFLinearMethod(LinearMethodBase): + """Linear method for GGUF. + + Args: + quant_config: The GGUF quantization config. 
+ """ + + def __init__(self, quant_config: GGUFConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + self.params_dtype = params_dtype + output_size_per_partition = sum(output_partition_sizes) + + tensor_shape = (output_size_per_partition, input_size_per_partition) + qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + "shard_id": [], + "shard_id_map": {}, + "params_dtype": params_dtype, + "input_size_per_partition" :input_size_per_partition, # restore shape for qkv and merge + "output_partition_sizes" :output_partition_sizes, + }) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qweight", qweight) + + qweight_type = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "shard_weight_type": {}, + "ignore_warning": True + }) + set_weight_attrs(qweight_type, extra_weight_attrs) + layer.register_parameter("qweight_type", qweight_type) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + shard_id = getattr(layer.qweight, "shard_id", None) + keep_half_weight = getattr(self.quant_config,"keep_half_weight",None) # use for MLA fused.. + params_dtype = getattr(layer.qweight, "params_dtype", torch.get_default_dtype()) + if shard_id: + # dequantize shard weights respectively + shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + qweight = layer.qweight.unbind(0) + weights = [] + input_size_per_partition = layer.qweight.input_size_per_partition + output_partition_sizes = layer.qweight.output_partition_sizes + orig_type_weights = [] + for idx in shard_id: + q_idx = layer.qweight.shard_id_map[idx] + qweight_type = layer.qweight_type.shard_weight_type[idx] + weight = dequant_gguf(qweight[q_idx].data.cpu().numpy().reshape(-1), qweight_type, (output_partition_sizes[q_idx], input_size_per_partition)) + weights.append(weight.to(qweight[q_idx].data.device)) + + weights = torch.cat(weights, axis=0) + qweight, qzeros, scales = get_awq_format(weights) + if keep_half_weight: + layer.weight = torch.nn.Parameter(weights.to(params_dtype), + requires_grad=False) + assert layer.weight.is_cuda + else: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + weight = dequant_gguf(qweight.data.cpu().numpy().reshape(-1), qweight_type, qweight.tensor_shape) + qweight, qzeros, scales = get_awq_format(weight.to(qweight.data.device)) + if keep_half_weight: + layer.weight = torch.nn.Parameter(weight.to(qweight.data.device).to(params_dtype), + requires_grad=False) + assert layer.weight.is_cuda + del layer.qweight + del layer.qweight_type + layer.qweight = torch.nn.Parameter(qweight.data, + requires_grad=False) + layer.qzeros = torch.nn.Parameter(qzeros.data, + requires_grad=False) + layer.scales = torch.nn.Parameter(scales.data, + requires_grad=False) + assert layer.qweight.is_cuda + assert layer.qzeros.is_cuda + assert layer.scales.is_cuda + # if layer.weight is not None, we will use it on triton_mla.process_weights_after_loading and delete it. + # Mainly used to avoid accuracy loss in quantization and dequantization on MLA martix absorption. 
+ # PS: layer.weight is dequantizing from gguf format + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ gguf format use awq kernel for now.. + shard_id = getattr(layer.qweight, "shard_id", None) + + if shard_id: + # dequantize shard weights respectively + shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + qweight = layer.qweight.unbind(0) + result = [] + for idx in shard_id: + q_idx = layer.qweight.shard_id_map[idx] + qweight_type = layer.qweight_type.shard_weight_type[idx] + result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type)) + out = torch.cat(result, axis=1) + else: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + out = _fuse_mul_mat(x, qweight, qweight_type) + if bias is not None: + out.add_(bias) + return out + """ + pack_factor = 8 + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros + out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + reshaped_x = x.reshape(-1, x.shape[-1]) + + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor, group_size=128) + if bias is not None: + out.add_(bias) + return out.reshape(out_shape) + + +class GGUFMoEMethod(FusedMoEMethodBase): + """MoE method for GGUF. + + Args: + quant_config: The GGUF quantization config. + """ + + def __init__(self, quant_config: GGUFConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + tensor_shape = (num_experts, 2 * intermediate_size_per_partition, + hidden_size) + #gate up proj + w13_qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + w13_qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + }) + set_weight_attrs(w13_qweight, extra_weight_attrs) + layer.register_parameter("w13_qweight", w13_qweight) + + w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), + requires_grad=False) + set_weight_attrs(w13_qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "ignore_warning": True + }) + set_weight_attrs(w13_qweight_type, extra_weight_attrs) + layer.register_parameter("w13_qweight_type", w13_qweight_type) + + tensor_shape = (num_experts, intermediate_size_per_partition, + hidden_size) + #gate down proj + w2_qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + w2_qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + }) + set_weight_attrs(w2_qweight, extra_weight_attrs) + layer.register_parameter("w2_qweight", w2_qweight) + + w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), + requires_grad=False) + set_weight_attrs(w2_qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "ignore_warning": True + }) + + set_weight_attrs(w2_qweight_type, extra_weight_attrs) + layer.register_parameter("w2_qweight_type", w2_qweight_type) + self.act = SiluAndMul() + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + 
scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + assert activation == "silu", "Only SiLU activation is supported." + if apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused GGUF MoE method.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, + topk_weights, topk_ids, + layer.w13_qweight_type.weight_type, + layer.w2_qweight_type.weight_type, self.act) + + +class GGUFEmbeddingMethod(GGUFLinearMethod): + """Embedding method for GGUF. + + Args: + quant_config: The GGUF quantization config. + """ + + def embedding(self, layer: torch.nn.Module, + x: torch.Tensor) -> torch.Tensor: + weight = layer.weight + return F.embedding(x, weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + qweight = layer.qweight + qweight_type = layer.qweight_type + + params_dtype = getattr(qweight, "params_dtype", torch.get_default_dtype()) + weight = dequant_gguf(qweight.data.cpu().numpy().reshape(-1), qweight_type.weight_type, qweight.tensor_shape) + + del layer.qweight + del layer.qweight_type + layer.weight = torch.nn.Parameter(weight.to(params_dtype).to(qweight.device), + requires_grad=False) + assert layer.weight.is_cuda + assert layer.weight.dtype == params_dtype + if layer.__class__.__name__ == "ParallelLMHead": + from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod + layer.quant_method = UnquantizedEmbeddingMethod() + + +class GGUFCPUInfer: + cpuinfer = None + + def __init__(self) -> None: + import os + thread_num = int(os.environ.get("VLLM_GGUF_THREAD_NUMS", 33)) + if GGUFCPUInfer.cpuinfer is None: + GGUFCPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num) + + +class GGUFMoEMethod(FusedMoEMethodBase): + CPUINfer = None + CUR_STREAM = None + INPUT_TENSOR_CPU = None + TOPK_IDS_CPU = None + TOPK_WEIGHT_CPU = None + OUTPUT_CPU = None + OUTPUT_GPU = None + MAX_B = 0 + + def __init__(self, quant_config: GGUFConfig, prefix: str): + if GGUFMoEMethod.CPUINfer is None: + if cpuinfer_ext is None: + raise RuntimeError("ktransformers not installed!!") + else: + GGUFMoEMethod.CPUINfer = GGUFCPUInfer() + self.quant_config = quant_config + self.prefix = prefix + + def create_weights(self, layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs): + + gate_up_tensor_shape = (intermediate_size_per_partition, hidden_size, num_experts) + down_tensor_shape = (hidden_size, intermediate_size_per_partition, num_experts) + gate_weight = GGUFUninitializedParameter(requires_grad=False) + up_weight = GGUFUninitializedParameter(requires_grad=False) + down_weight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + gate_weight, { + "tensor_shape": gate_up_tensor_shape, + "is_gguf_weight": True, + "data_container": [] + }) + set_weight_attrs( + up_weight, { + "tensor_shape": gate_up_tensor_shape, + "is_gguf_weight": True, + "data_container": [] + }) + set_weight_attrs( + down_weight, 
{ + "tensor_shape": down_tensor_shape, + "is_gguf_weight": True, + "data_container": [] + }) + set_weight_attrs(gate_weight, extra_weight_attrs) + set_weight_attrs(up_weight, extra_weight_attrs) + set_weight_attrs(down_weight, extra_weight_attrs) + layer.register_parameter("gate_weight", gate_weight) + layer.register_parameter("up_weight", up_weight) + layer.register_parameter("down_weight", down_weight) + setattr(layer, "gate_type", None) + setattr(layer, "up_type", None) + setattr(layer, "down_type", None) + + + def process_weights_after_loading(self, layer: torch.nn.Module, + all_tensors: gguf.ReaderTensor, + gguf_to_hf_name_map: Dict[str, str]) -> None: + all_done = [None] * 3 + ptrs = [None] * 3 + types = [None] * 3 + prefix = self.prefix.replace("model.layers","blk").replace(".mlp.experts","") + for tensor in all_tensors: + # if tensor.name in gguf_to_hf_name_map: + if tensor.name == prefix + ".ffn_gate_exps.weight": + data = tensor.data + index = 0 + if all(all_done): + assert False, "repeat weights!!!" + elif tensor.name == prefix + ".ffn_up_exps.weight": + data = tensor.data + index = 1 + if all(all_done): + assert False, "repeat weights!!!" + elif tensor.name == prefix + ".ffn_down_exps.weight": + data = tensor.data + index = 2 + if all(all_done): + assert False, "repeat weights!!!" + else: + index = -1 + if index != -1: + types[index] = tensor.tensor_type + ptrs[index] = ctypes.addressof( + ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + all_done[index] = True + assert all(all_done) + + moe_config = MOEConfig( + layer.num_experts, + layer.top_k, + layer.hidden_size, + layer.intermediate_size_per_partition, + 64, + 10, + 1024, + ptrs[0], + ptrs[1], + ptrs[2], + types[0], + types[1], + types[2], + 30 if torch.get_default_dtype() == torch.bfloat16 else 1, + ) + cpu_moe = MOE(moe_config) + self.CPUINfer.cpuinfer.submit(cpu_moe.warm_up()) + self.CPUINfer.cpuinfer.sync() + setattr(layer, "cpu_moe", cpu_moe) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + extra_residual: Optional[torch.Tensor] = None, + routed_scaling_factor: float = 1.0, + ) -> torch.Tensor: + B = x.shape[0] + topk_ids = torch.empty([B, top_k], dtype=torch.long, device=x.device) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + topk_ids=topk_ids) + + if GGUFMoEMethod.CUR_STREAM is None: + GGUFMoEMethod.CUR_STREAM = torch.cuda.default_stream() + if GGUFMoEMethod.MAX_B < B: + GGUFMoEMethod.MAX_B = B + GGUFMoEMethod.INPUT_TENSOR_CPU = torch.empty_like(x, dtype=x.dtype, device="cpu", pin_memory=True) + GGUFMoEMethod.TOPK_IDS_CPU = torch.empty_like(topk_ids, dtype=torch.long, device="cpu",pin_memory=True) + GGUFMoEMethod.TOPK_WEIGHT_CPU = torch.empty_like(topk_weights, dtype=topk_weights.dtype, device="cpu",pin_memory=True) + GGUFMoEMethod.OUTPUT_CPU = torch.zeros_like(x, dtype=x.dtype, device="cpu",pin_memory=True) + 
GGUFMoEMethod.OUTPUT_GPU = torch.zeros_like(x, dtype=x.dtype, device=x.device) + + input_tensor_cpu = GGUFMoEMethod.INPUT_TENSOR_CPU[:B] + topk_ids_cpu = GGUFMoEMethod.TOPK_IDS_CPU[:B] + topk_weights_cpu = GGUFMoEMethod.TOPK_WEIGHT_CPU[:B] + output_cpu = GGUFMoEMethod.OUTPUT_CPU[:B] + output_gpu = GGUFMoEMethod.OUTPUT_GPU[:B] + + input_tensor_cpu.copy_(x, non_blocking=True) + topk_ids_cpu.copy_(topk_ids, non_blocking=True) + topk_weights_cpu.copy_(topk_weights, non_blocking=True) + GGUFMoEMethod.CUR_STREAM.synchronize() + cpu_moe = layer.cpu_moe + self.CPUINfer.cpuinfer.submit( + cpu_moe.forward(topk_ids.size(0), topk_ids.size(1), topk_ids_cpu.data_ptr(), topk_weights_cpu.data_ptr(), input_tensor_cpu.data_ptr(), output_cpu.data_ptr())) + self.CPUINfer.cpuinfer.sync() + output_gpu.copy_(output_cpu, non_blocking=True) + torch.cuda.default_stream().synchronize() + return GGUFMoEMethod.OUTPUT_GPU[:B] + + +class GGUFUninitializedParameter(UninitializedParameter): + cls_to_become = Parameter + data_container: List[torch.Tensor] + + def materialize_nested(self) -> Parameter: + dtype = {data.dtype for data in self.data_container} + assert len(dtype) == 1, ValueError( + f"Data container has mixed dtypes: {dtype}") + dtype = next(iter(dtype)) + nested_data = torch.nested.nested_tensor(self.data_container, + device=self.device, + dtype=dtype) + self.data_container.clear() + param = torch.Tensor._make_subclass(self.cls_to_become, + nested_data, + require_grad=False) + for k, v in self.__dict__.items(): + setattr(param, k, v) + return param diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py new file mode 100644 index 0000000..59f07ba --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -0,0 +1,276 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +from enum import Enum +from fractions import Fraction +from typing import Any, Dict, List, Optional, Union + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_linear_quant_method) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. + + Reference: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + lm_head_quantized: bool, + dynamic: Dict[str, Dict[str, Union[int, bool]]], + ) -> None: + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is Dict[str, Dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. 
More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + super().__init__() + self.dynamic = dynamic + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.pack_factor = Fraction(32, self.weight_bits) + if self.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}), " + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}") + + @classmethod + def get_name(cls) -> str: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, lm_head_quantized, + dynamic) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQLinearMethod"]: + return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + + +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() + + +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. + + Args: + quant_config: The GPTQ quantization config. + """ + + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + output_size_per_partition = sum(output_partition_sizes) + if (output_size_per_partition % self.quant_config.pack_factor.numerator + != 0): + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + exllama_state = ExllamaState.UNINITIALIZED + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if (input_size != input_size_per_partition + and self.quant_config.group_size != -1): + # For act-order models, we cannot use Exllama for row parallel layer + if self.quant_config.desc_act: + exllama_state = ExllamaState.UNUSED + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + g_idx = RowvLLMParameter(data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + qzeros_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.exllama_state = exllama_state + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) + + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if layer.exllama_state == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) + else: + layer.g_idx.data = torch.empty((0, ), + dtype=torch.int, + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, + self.quant_config.weight_bits) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == ExllamaState.READY, + self.quant_config.weight_bits) + if bias is not None: + 
output.add_(bias) + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py new file mode 100644 index 0000000..0615bb4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -0,0 +1,634 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import torch + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_linear_quant_method) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class GPTQMarlinConfig(QuantizationConfig): + """Config class for GPTQ Marlin""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + is_sym: bool, lm_head_quantized: bool, + dynamic: Dict[str, Dict[str, Union[int, bool]]], + full_config: Dict[str, Any]) -> None: + super().__init__() + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is Dict[str, Dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. 
More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + self.dynamic = dynamic + + self.weight_bits = weight_bits + self.is_sym = is_sym + + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.full_config = full_config + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}") + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, + lm_head_quantized, dynamic, config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "marlin" + or user_quant == "gptq_marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. 
Use quantization=gptq_marlin for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, FusedMoE): + if layer.local_num_experts > 32: + # For MoEs with many experts the moe_wna16 kernel is faster + return MoeWNA16Config.from_config( + self.full_config).get_quant_method(layer, prefix) + else: + return GPTQMarlinMoEMethod(self) + return get_linear_quant_method(self, layer, prefix, + GPTQMarlinLinearMethod) + + @classmethod + def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + if not current_platform.is_cuda(): + return False + + if quant_method != "gptq": + return False + + # Marlin conversion is only valid if required properties are found + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], + group_size=group_size) + + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. + + Args: + quant_config: The GPTQ Marlin quantization config. + """ + + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + # Verify supported on platform. + verify_marlin_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQMarlinLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. 
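+            # e.g. with input_size_per_partition = 4096 and group_size = 128,
+            # each rank keeps 4096 // 128 = 32 rows of scales and zero-points
+            # for its shard of the input dimension.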
+ scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + # Activation order + g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + + qzeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + intermediate_size_full = extra_weight_attrs.pop( + "intermediate_size_full") + + self.is_k_full = (not self.quant_config.desc_act) or ( + intermediate_size_per_partition == intermediate_size_full) + + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + w2_scales_size = (intermediate_size_full + if self.quant_config.desc_act else + intermediate_size_per_partition) + scales_size2 = (w2_scales_size // self.quant_config.group_size) + strategy = FusedMoeWeightScaleSupported.GROUP.value + else: + scales_size13 = 1 + scales_size2 = 1 + strategy = FusedMoeWeightScaleSupported.CHANNEL.value + + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": True + }) + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, 
extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // + self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size_per_partition, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # dont shard the w2 scales when running act order + set_weight_attrs(w2_scales, + {"load_full_w2": self.quant_config.desc_act}) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + # dont shard the w2 scales when running act order + set_weight_attrs(w2_qzeros, + {"load_full_w2": self.quant_config.desc_act}) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + 
w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * + (self.quant_config.group_size if self.quant_config.group_size != -1 + else self.quant_config.pack_factor), + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." 
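+ # Folding the router weights into the expert inputs is not implemented
+ # for the fused Marlin MoE path, hence the NotImplementedError below.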
+ if apply_router_weight_on_input is not None: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused Marlin MoE method.") + + # The input must currently be float16 + orig_dtype = x.dtype + x = x.half() + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.quant_config.quant_type.size_bits, + is_k_full=self.is_k_full).to(orig_dtype) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py new file mode 100644 index 0000000..dd747e1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ + scalar_types.uint4b8, scalar_types.uint8b128 +] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + + +class GPTQMarlin24Config(QuantizationConfig): + """Config class for Marlin24. + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + ) -> None: + super().__init__() + quant_type = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, + }.get(weight_bits) + + self.group_size = group_size + + # Verify + if quant_type is None or \ + quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: + raise ValueError( + f"Marlin_24 does not support quant_type = {quant_type}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Marlin_24 does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " + "are supported.") + + self.quant_type = quant_type + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.quant_type.size_bits + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL + + # Permutation length used by the marlin kernels. 
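+ # (1024 entries = 4 tiles of 16x16; create_weights below uses this to
+ # require that whole permutation groups stay on a single TP shard.)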
+ self.perm_len = 1024 + + def __repr__(self) -> str: + return "Marlin24Config(quant_type={}, group_size={})".format( + self.quant_type, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin_24" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + is_marlin_24_format = ( + hf_quant_cfg.get("checkpoint_format") == "marlin_24") + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "gptq_marlin_24") + + if is_marlin_24_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. " + "Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMarlin24LinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlin24LinearMethod(self) + return None + + +class GPTQMarlin24LinearMethod(LinearMethodBase): + """Linear method for Marlin24. + + Args: + quant_config: The Marlin24 quantization config. + """ + + def __init__(self, quant_config: GPTQMarlin24Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. 
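+ # The Marlin24 (2:4 sparse) kernel only runs in float16 and works on
+ # 16x16 tiles, so the per-partition shapes are validated below against
+ # min_n_threads / min_k_threads / pack_factor before any tensors are
+ # allocated.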
+ weight_loader = extra_weight_attrs["weight_loader"] + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size // 2, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + # Meta + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B_24", qweight) + layer.register_parameter("B_meta", meta) + layer.register_parameter("s", scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B_24 = 
Parameter(layer.B_24.data, requires_grad=False) + layer.s = Parameter(layer.s.data, requires_grad=False) + layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B_24 + meta = layer.B_meta + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, + self.quant_config.quant_type, + size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py new file mode 100644 index 0000000..4edc9aa --- /dev/null +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -0,0 +1,328 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + marlin_make_empty_g_idx, marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack +from vllm.model_executor.parameter import (BasevLLMParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class HQQMarlinConfig(QuantizationConfig): + """Config class for HQQ Marlin""" + + def __init__( + self, + weight_bits: int, + group_size: int, + skip_modules: Optional[List[str]] = None, + ) -> None: + super().__init__() + assert group_size == 64, ("The only supported HQQ group size is " + "currently 64.") + assert weight_bits == 4, ("The only supported HQQ quantization " + "bitsize is currently 4.") + + self.weight_bits = weight_bits + self.group_size = group_size + self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format + self.quant_type = scalar_types.uint4 + self.skip_modules = skip_modules + + def __repr__(self) -> str: + return (f"HQQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> str: + return "hqq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig": + wq_params = (config["quant_config"]["weight_quant_params"]) + weight_bits = cls.get_from_keys(wq_params, ["nbits"]) + group_size = cls.get_from_keys(wq_params, ["group_size"]) + skip_modules = config["skip_modules"] + return cls(weight_bits, group_size, skip_modules) + + def is_layer_skipped(self, prefix: str) 
-> bool: + # Split the prefix into its dot-separated components + components = prefix.split('.') + + # Check if any of the skip modules exactly matches any component + return self.skip_modules is not None and any( + module_name in components for module_name in self.skip_modules) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if self.is_layer_skipped(prefix): + return UnquantizedLinearMethod() + return HQQMarlinMethod(self) + return None + + +# Empty HQQ parameter, will be ignored during loading +class HQQEmptyParameter(BasevLLMParameter): + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + pass + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + pass + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + pass + + +def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + raise ValueError("No loader provided for HQQ parameter!") + + +# HQQ packing creates issues with sharding - therefore, prior to loading, we +# repack to GPTQ. We also reshape the weights to their proper GPTQ shape. +class HQQweightParameter(PackedvLLMParameter): + + # unpack function from https://github.com/mobiusml/hqq + def unpack_4bit_u8(self, + W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8 + assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)" + + dtype = torch.uint8 + step = W_q.shape[0] + tmp = torch.empty([2 * step, W_q.shape[1]], + dtype=dtype, + device=W_q.device) + tmp[:step] = (W_q & 0b11110000) >> 4 + tmp[step:] = W_q & 0b00001111 + return tmp + + def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, + **kwargs): + super().__init__(packed_factor, packed_dim, None, **kwargs) + self.weight_bits = weight_bits + self.input_shape = self.shape[self.input_dim] * self.packed_factor + self.output_shape = self.shape[self.output_dim] + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( + 1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_merged_column_weight(loaded_weight, **kwargs) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(self.output_shape, + -1).transpose(1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_row_parallel_weight(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( + 1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_qkv_weight(loaded_weight, **kwargs) + + +# Zero points and scales in HQQ must also be reshaped to correspond to W_q's +# GPTQ shape (transposed - we transpose them too when processing weights). 
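+ # (The loaders below reshape each HQQ scale/zero tensor to
+ # (out_features, n_groups) so that column/row sharding sees the same
+ # layout as the repacked weight; process_weights_after_loading later
+ # transposes them before the Marlin permute.)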
+class HQQZeroScaleParameter(GroupQuantScaleParameter): + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = loaded_weight.reshape(-1, self.shape[1]) + super().load_merged_column_weight(loaded_weight, **kwargs) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + loaded_weight = loaded_weight.reshape(self.shape[0], -1) + super().load_row_parallel_weight(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = loaded_weight.reshape(-1, self.shape[1]) + super().load_qkv_weight(loaded_weight, **kwargs) + + +class HQQMarlinMethod(LinearMethodBase): + """Linear method for HQQ Marlin. + """ + + def __init__( + self, + quant_config: HQQMarlinConfig, + ): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + self.output_size_per_partition = sum(output_partition_sizes) + self.input_size_per_partition = input_size_per_partition + + weight_loader = extra_weight_attrs.get("weight_loader", error_loader) + + self.scales_and_zp_size = (input_size_per_partition // + self.quant_config.group_size) + + qweight = HQQweightParameter( + data=torch.empty( + self.input_size_per_partition // self.quant_config.pack_factor, + self.output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_bits=self.quant_config.weight_bits, + weight_loader=weight_loader) + + zeros = HQQZeroScaleParameter(data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + scales = HQQZeroScaleParameter(data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("W_q", qweight) + layer.register_parameter("zero", zeros) + layer.register_parameter("scale", scales) + + # Ignore extra parameters in the HQQ model. + # To be added as needed. 
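+ # These names are HQQ checkpoint metadata (quantization settings,
+ # shapes, dtypes). Registering no-op HQQEmptyParameter entries for them
+ # lets the weight loader skip them without errors.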
+ ignore_parameters = ("axis", "channel_wise", "compute_dtype", + "encoded_state_dict", "group_size", "nbits", + "offload_meta", "optimize", "packing", + "quant_scale", "quant_zero", "round_zero", + "shape", "stores_quant_config", + "unpack_view_dtype", "view_as_float") + for name in ignore_parameters: + layer.register_parameter( + name, + HQQEmptyParameter(data=torch.empty(0), + weight_loader=weight_loader)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + dev = layer.W_q.device + + # Repack to Marlin + sort_indices = torch.empty(0, dtype=torch.int, device=dev) + marlin_w_q = ops.gptq_marlin_repack( + layer.W_q, + sort_indices, + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.weight_bits, + ).to(dev) + marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) + marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) + + layer.g_idx = marlin_make_empty_g_idx(dev) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) + + layer.marlin_qweight = marlin_w_q + layer.marlin_zeros = marlin_zp + layer.marlin_scales = marlin_s + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + workspace = MarlinWorkspace(self.output_size_per_partition, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + scales = layer.marlin_scales + zeros = layer.marlin_zeros + orig_type = x.dtype + + if orig_type != torch.float16: + x = x.to(torch.float16) + scales = scales.to(torch.float16) + zeros = zeros.to(torch.float16) + + marlin_out = ops.gptq_marlin_gemm( + x, + layer.marlin_qweight, + scales, + zeros, + layer.g_idx, + layer.g_idx_sort_indices, + workspace.scratch, + scalar_types.uint4, + x.shape[0], + self.output_size_per_partition, + self.input_size_per_partition, + True, # is_k_full + True, # has_zp + True, # use 32-bit reduce + True, # use float zp + ) + + if orig_type != torch.float16: + marlin_out = marlin_out.to(orig_type) + + if bias is not None: + marlin_out.add_(bias) + + return marlin_out diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py new file mode 100644 index 0000000..c09cc13 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, + is_layer_skipped_awq) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod +from vllm.platforms import current_platform + +MIN_IPEX_VERSION = "2.5.0" + + +class IPEXConfig(QuantizationConfig): + """INT8 quantization config class using IPEX for the CPU/XPU backend, + including AWQ, GPTQ. 
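+
+ Weights stay INT4 on disk; IPEX executes them with INT8 compute
+ (lowp_mode) and dynamic per-token activation quantization.
+
+ Illustrative quantize_config.json for the GPTQ path (field names as
+ read by from_config below; values other than bits=4 are only examples):
+ {"quant_method": "gptq", "bits": 4, "group_size": 128,
+ "desc_act": false, "lm_head": false}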
+ """ + + IPEX_QUANT_METHOD_MAP = { + "awq": 1, + "gptq": 0, + } + + def __init__( + self, + method: str, + weight_bits: int, + group_size: int, + modules_to_not_convert: Optional[List[str]] = None, + desc_act: Optional[bool] = None, + lm_head_quantized: Optional[bool] = None, + ) -> None: + super().__init__() + self.method = method + self.weight_bits = weight_bits + self.group_size = group_size + self.modules_to_not_convert = modules_to_not_convert or [] + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.pack_factor = 32 // self.weight_bits + + if self.weight_bits not in [4]: + raise ValueError(f"IPEX quantization supports weight bits [4], " + f"but got {self.weight_bits}.") + + if self.method not in ["awq", "gptq"]: + raise ValueError(f"IPEX quantization supports [awq, gptq], " + f"but got {self.method}.") + + def __repr__(self) -> str: + return (f"IPEXConfig(method={self.method}," + f"weight_bits={self.weight_bits}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> str: + return "ipex" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return -1 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig": + method = cls.get_from_keys(config, ["quant_method"]).lower() + if method == "awq": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, + ["q_group_size", "group_size"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(method, weight_bits, group_size, modules_to_not_convert, + False, False) + # otherwise for gptq + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) + return cls(method, weight_bits, group_size, [], desc_act, + lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + if not current_platform.is_cpu() and not current_platform.is_xpu(): + return None + + quant_method = hf_quant_cfg.get("quant_method", "").lower() + + if quant_method in ["awq", "gptq"]: + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["LinearMethodBase"]: + if isinstance(layer, LinearBase): + if self.method == "awq": + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return IPEXAWQLinearMethod(self) + if self.method == "gptq": + return IPEXGPTQLinearMethod(self) + return None + + +class IPEXGPTQLinearMethod(GPTQLinearMethod): + """GPTQ linear method using IPEX for the CPU/XPU backend. + """ + + def __init__(self, quant_config: IPEXConfig): + self.quant_config = quant_config # type: ignore + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. 
Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + except ImportError as err: + raise ImportError( + "Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" + " to use IPEX-AWQ linear method.") from err + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + layer.ipex_output_size = layer.qweight.shape[-1] + g_idx = layer.g_idx if self.quant_config.desc_act else None + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + g_idx=g_idx, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"] + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + if bias is not None: + out.add_(bias) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + + +class IPEXAWQLinearMethod(AWQLinearMethod): + """AWQ linear method using IPEX for the CPU/XPU backend. + """ + + def __init__(self, quant_config: IPEXConfig): + self.quant_config = quant_config # type: ignore + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer=layer) + + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + except ImportError as err: + raise ImportError( + "Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" + " to use IPEX-AWQ linear method.") from err + + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + + layer.ipex_output_size = layer.qweight.size( + 1) * self.quant_config.pack_factor + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. 
\ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py new file mode 100644 index 0000000..c06befa --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + if c.zero_points: + assert w_zp_param_name is not None + if c.has_g_idx: + assert w_gidx_param_name is not None + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module) -> Tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py 
b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py new file mode 100644 index 0000000..520e1bc --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Type + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 + AllSparkLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 + ExllamaLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + AllSparkLinearKernel, + MarlinLinearKernel, + ExllamaLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. 
Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py new file mode 100644 index 0000000..56fdd6a --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class AllSparkLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx: + return False, "Act reordering currently not supported by AllSpark" + + if c.zero_points: + return False, "Zero points currently not supported by AllSpark" + + return check_allspark_supported_dtype_shape( + c.partition_weight_shape[0], # in_features + c.partition_weight_shape[1], # out_features + c.group_size, + c.weight_type, + c.act_type) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + # prepare the parameters required for the kernel + properties = torch.cuda.get_device_properties(device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + gemm_args = {} + gemm_args['sm_count'] = sm_count + gemm_args['sm_version'] = sm_version + + self.gemm_args = gemm_args + + # transform param weight, scale + old_weight_param = getattr(layer, self.w_q_name) + old_scale_param = getattr(layer, self.w_s_name) + + assert isinstance(old_weight_param, BasevLLMParameter) + permute_param_layout_(old_weight_param, + input_dim=0, + output_dim=1, + packed_dim=0) + + assert isinstance(old_scale_param, BasevLLMParameter) + permute_param_layout_(old_scale_param, input_dim=0, output_dim=1) + + # unpack weight from K / 4 x N int32 to K x N uint8 + new_weight_param = torch.nn.Parameter(old_weight_param.data, + requires_grad=False) + new_weight_param.data = new_weight_param.data.t().contiguous().view( + dtype=torch.uint8) + new_weight_param.data = new_weight_param.data.t().contiguous() + + new_scale_param = torch.nn.Parameter(old_scale_param.data, + requires_grad=False) + + # reorder K x N weight as N32K16 format for Ampere W8A16 + new_weight_param.data, new_scale_param.data, _ = \ + ops.allspark_repack_weight( + new_weight_param.data, new_scale_param.data, None, + c.zero_points) + + replace_parameter(layer, self.w_q_name, new_weight_param.data) + replace_parameter(layer, self.w_s_name, new_scale_param.data) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + gemm_args = self.gemm_args + w_q, w_s, _, _ = self._get_weight_params(layer) + + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + 
(c.partition_weight_shape[1], ) + + output = ops.allspark_w8a16_gemm( + a=reshaped_x, + b_qweight=w_q, + b_scales=w_s, + b_qzeros=None, + n=c.partition_weight_shape[1], + group_size=c.group_size, + sm_count=gemm_args['sm_count'], + sm_version=gemm_args['sm_version'], + CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp=c.zero_points, + n32k16_reorder=True) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py new file mode 100644 index 0000000..2706fbb --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class ExllamaLinearKernel(MPLinearKernel): + SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] + # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but + # currently untested so not added to the list + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Exllama, "\ + "when the input features are partitioned across "\ + "devices" + + if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: + return False, "Output features must be a multiple of the pack " \ + "factor (32 / num_bits) so that we can correctly " \ + "pack the zero points" + + if c.act_type != torch.float16: + return False, "Exllama only supports float16 activations" + + if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Exllama, supported types are: "\ + f"{cls.SUPPORTED_QUANT_TYPES}" + + if c.full_weight_shape[0] % c.group_size != 0: + return False, f"Group size ({c.group_size}) does not evenly divide"\ + " the number of input features "\ + f"({c.full_weight_shape[0]})" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + # For Exllama, we need to set a zero-point tensor if there is not one + if not c.zero_points: + self.w_zp_name = "qzeros" + device = getattr(layer, self.w_q_name).device + groups = c.partition_weight_shape[0] // c.group_size + out_features = c.partition_weight_shape[1] + + if c.weight_type.has_bias(): + # if the type has a bias we have to create a zeros tensor that + # contains the bias values repeated for each group (-1 due to + # a bug in the original GPTQ checkpoint format leading to + # exllama kernel adding 1 to the zero points during inference) + # Documentation of the bug can be found here: + # https://garden.danieldk.eu/GPTQ-Checkpoint-Format + zeros = torch.full((groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device) + else: + raise NotImplementedError( + "A 0 zero-point is not supported by Exllama due to " + "a bug in the original 
GPTQ checkpoint format leading to " + "exllama kernel adding 1 to the zero points during " + "inference") + zeros = pack_quantized_values_into_int32(zeros, + c.weight_type, + packed_dim=1) + setattr(layer, self.w_zp_name, + torch.nn.Parameter(zeros, requires_grad=False)) + + if c.has_g_idx: + + def transform_w_g_idx(x): + # Exllama wants the permutation array instead of the group + # indices + return torch.argsort(x).to(torch.int) + + self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) + else: + self.w_gidx_name = "g_idx" + empty_g_idx = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int, + device=device), + requires_grad=False) + setattr(layer, self.w_gidx_name, empty_g_idx) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + assert self.w_gidx_name is not None + g_idx = getattr(layer, self.w_gidx_name) + + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x_cont = x.data.contiguous() + ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits) + return x_cont + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x.to(dtype=c.act_type) + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) + + assert w_zp is not None, "Zero points are required by Exllama" + assert w_g_idx is not None, "Group index is required by Exllama" + output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, + c.weight_type.size_bits) + + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py new file mode 100644 index 0000000..3f0586f --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import partial +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32, unpack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. 
(Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_quantized_values_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_mm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=None, + b_group_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py new file mode 100644 index 0000000..e21801c --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + 
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = 
self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py new file mode 100644 index 0000000..91e7654 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + + +@dataclass +class ScaledMMLinearLayerConfig: + is_channelwise: bool + is_static_input_scheme: bool + input_symmetric: bool + + +class ScaledMMLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, + w_s_param_name: str, i_s_param_name: str, + i_zp_param_name: str, azp_adj_param_name: str) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.i_s_name = i_s_param_name + self.i_zp_name = i_zp_param_name + self.azp_adj_name = azp_adj_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _get_weight_params( + self, layer: torch.nn.Module) -> Tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + Optional[torch.Tensor], # input_zp + Optional[torch.Tensor], # azp_adj + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.i_s_name), + getattr(layer, self.i_zp_name), + getattr(layer, self.azp_adj_name), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py new file mode 100644 index 0000000..bedda4c --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Dict, List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( + AiterScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( + TritonScaledMMLinearKernel) +from 
vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( + XLAScaledMMLinearKernel) +from vllm.platforms import PlatformEnum, current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { + PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], + PlatformEnum.TPU: [XLAScaledMMLinearKernel], +} + + +def choose_scaled_mm_linear_kernel( + config: ScaledMMLinearLayerConfig, + compute_capability: Optional[int] = None +) -> Type[ScaledMMLinearKernel]: + """ + Choose an ScalledMMLinearKernel that can implement the given config for the + given compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (ScaledMMLinearLayerConfig): Description of the linear layer + to be implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the + compute capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[ScaledMMLinearKernel]: Chosen kernel. + """ + + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. + if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if (kernel_min_capability is not None + and kernel_min_capability > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "ScaledMM linear layer. 
Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py new file mode 100644 index 0000000..582b12f --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig + + +class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if not current_platform.is_rocm(): + return ( + False, + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "currently supported on non-ROCm platform.") + + try: + import aiter # noqa: F401 # deliberately attempt to import aiter + except Exception: + return ( + False, + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "installed on ROCm.") + # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled + if not ( + envs.VLLM_ROCM_USE_AITER_LINEAR \ + and envs.VLLM_ROCM_USE_AITER + ): + return (False, "AiterScaledMMLinearKernel is disabled. " + + "Enable by setting `VLLM_ROCM_USE_AITER=1` " + + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.") + + if not c.input_symmetric: + return (False, + "AiterScaledMMLinearKernel only supports symmetric " + + "quantization.") + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + `AiterScaledMMLinearKernel` implements a fused version of + `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` + where scale_a * a and scale_b * b are implemented using numpy-style + broadcasting. + Currently only support per-tensor-per-tensor GEMM + and per-token-per-channel GEMM through AITER + w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support + ATIER block scaled GEMM and mix-precision GEMM. + """ + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. 
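+        # azp_adj is only created for asymmetric input quantization (see
+        # CutlassScaledMMLinearKernel.process_weights_after_loading), so its
+        # absence is used here as a proxy for "symmetric". With dynamic
+        # quantization i_s is None and per-token scales are computed from x;
+        # with static quantization the checkpoint scale i_s is reused as x_s.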
+ symmetric = azp_adj is None + assert symmetric, ("AiterScaledMMLinearKernel only supports" + " symmetric quantization.") + x_q, x_s, x_zp = ops.scaled_int8_quant(x, + i_s, + i_zp, + symmetric=symmetric) + + assert x_zp is None, ("AiterScaledMMLinearKernel only supports" + " symmetric quantization.") + out_dtype = x.dtype + + assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == w_q.shape[ + 1] and bias.dtype == out_dtype + + m = x_q.shape[0] # a + n = w_q.shape[1] # b + + per_tensor_scale_a = (x_s.numel() == 1) + per_tensor_scale_b = (w_s.numel() == 1) + per_token_scale_a = (x_s.numel() == m) + per_channel_scale_b = (w_s.numel() == n) + + # @TODO: + # Maybe broadcast the per-tensor-scale into per-channel-scale + # if one of the scale is a per-channel-scale. + # For now, it only supports: + # - per-tensor-per-tensor a8w8 scaled GEMM, and + # - per-token-per-channel a8w8 scaled GEMM + assert ((per_tensor_scale_a and per_tensor_scale_b) + or (per_token_scale_a and per_channel_scale_b)), ( + "Currently only support per-tensor-per-tensor GEMM " + + " and per-token-per-channel GEMM through AITER" + " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + + "does not support AITER block scaled GEMM.") + + from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py new file mode 100644 index 0000000..14afa3d --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if (not current_platform.is_cuda() and not current_platform.is_cpu()): + return False, "CutlassScaledMM requires running on CUDA or CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. 
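+        # Layout selection: when both weight dimensions are multiples of 64,
+        # the weight is re-materialized contiguously and self.format is set to
+        # "NN"; otherwise K is padded up to the next multiple of 64 and the
+        # default transposed "TN" format is kept. self.format is forwarded to
+        # ops.cutlass_scaled_mm in apply_weights() below.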
+        weight = getattr(layer, self.w_q_name)
+        self.format = "TN"  # by default the weight is stored in the transposed ("T") layout
+        m, k = weight.shape
+        if m % 64 == 0 and k % 64 == 0:
+            self.format = "NN"
+            replace_parameter(
+                layer, self.w_q_name,
+                torch.nn.Parameter(weight.t().data.contiguous(),
+                                   requires_grad=False))  # original layout is T[m, k]; after processing it is N[k, m]
+        else:
+            if k % 64 != 0:
+                pad_k = (k // 64 + 1) * 64
+                weight_pad = torch.empty((m, pad_k),
+                                         dtype=weight.dtype,
+                                         device=weight.device)
+                _weight = weight_pad[:, :k]
+                _weight.copy_(weight)
+                weight = _weight
+            replace_parameter(
+                layer, self.w_q_name,
+                torch.nn.Parameter(weight.t(), requires_grad=False))
+
+        # WEIGHT SCALE
+        # Cutlass kernels support only per-tensor and per-channel.
+        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
+        # scales being passed to the kernel), convert to the per-channel case.
+        is_fused_module = len(layer.logical_widths) > 1
+        weight_scale = getattr(layer, self.w_s_name)
+        if is_fused_module and not self.config.is_channelwise:
+            weight_scale = convert_to_channelwise(weight_scale,
+                                                  layer.logical_widths)
+        replace_parameter(
+            layer, self.w_s_name,
+            torch.nn.Parameter(weight_scale.data, requires_grad=False))
+
+        # INPUT SCALE
+        if self.config.is_static_input_scheme:
+            input_scale = getattr(layer, self.i_s_name)
+
+            if self.config.input_symmetric:
+                replace_parameter(
+                    layer, self.i_s_name,
+                    torch.nn.Parameter(input_scale.max(),
+                                       requires_grad=False))
+                setattr(layer, self.i_zp_name, None)
+            else:
+                input_zero_point = getattr(layer, self.i_zp_name)
+
+                # reconstruct the ranges
+                int8_traits = torch.iinfo(torch.int8)
+                azps = input_zero_point.to(dtype=torch.int32)
+                range_max = (input_scale * (int8_traits.max - azps)).max()
+                range_min = (input_scale * (int8_traits.min - azps)).min()
+
+                scale = (range_max - range_min) / (int8_traits.max -
+                                                   int8_traits.min)
+                replace_parameter(
+                    layer, self.i_s_name,
+                    torch.nn.Parameter(scale, requires_grad=False))
+
+                # AZP loaded as int8 but used as int32
+                azp = (int8_traits.min -
+                       range_min / scale).to(dtype=torch.int32)
+                replace_parameter(layer, self.i_zp_name,
+                                  torch.nn.Parameter(azp, requires_grad=False))
+
+        else:
+            setattr(layer, self.i_s_name, None)
+            setattr(layer, self.i_zp_name, None)
+
+        # azp_adj is the AZP adjustment term, used to account for weights.
+        # It does not depend on scales or azp, so it is the same for
+        # static and dynamic quantization.
+        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        if not self.config.input_symmetric:
+            weight = getattr(layer, self.w_q_name)
+            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.config.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = getattr(layer, self.i_zp_name) * azp_adj
+            setattr(layer, self.azp_adj_name,
+                    torch.nn.Parameter(azp_adj, requires_grad=False))
+        else:
+            setattr(layer, self.azp_adj_name, None)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        # ops.scaled_int8_quant supports both dynamic and static quant:
+        # * dynamic, i_s is None and x_s computed from x.
+        # * static, i_s is scalar and x_s is i_s.
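+        # azp_adj is only set for asymmetric input quantization, so its
+        # absence selects the symmetric path below: ops.cutlass_scaled_mm with
+        # the format chosen in process_weights_after_loading. The asymmetric
+        # path dispatches to ops.cutlass_scaled_mm_azp instead.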
+ symmetric = azp_adj is None + x_q, x_s, x_zp = ops.scaled_int8_quant(x, + i_s, + i_zp, + symmetric=symmetric) + + if x_zp is not None: + # Currently, static is always per-tensor and dynamic is per-token + static = i_zp is not None + azp = None if static else x_zp + return ops.cutlass_scaled_mm_azp(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias) + return ops.cutlass_scaled_mm(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + bias=bias, + format=self.format) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py new file mode 100644 index 0000000..5da5df8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm.platforms import current_platform + +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig + + +class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if current_platform.is_cpu(): + return ( + False, + "TritonScaledMMLinearKernel requires Triton which is not " + + "currently supported on CPU.") + if not c.input_symmetric: + return (False, + "TritonScaledMMLinearKernel only supports symmetric " + + "quantization.") + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return super().apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py new file mode 100644 index 0000000..0893140 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from typing import Optional, Tuple + +import torch +from functorch.experimental.control_flow import cond # noqa: F401 + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class XLAScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "TPU platform does have a concept of compute capability, " + "this method should not be called.") + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if not current_platform.is_tpu(): + return False, "ScaledMMXLA requires running on TPU." + + if c.is_static_input_scheme: + return False, "ScaledMMXLA requires dynamic activation scales." + + if not c.input_symmetric: + return False, "ScaledMMXLA requires symmetric activation scales." 
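+        # The XLA quantized_matmul op used in apply_weights() consumes a
+        # per-output-channel scale vector, hence the channelwise requirement
+        # below.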
+ + if not c.is_channelwise: + return False, "ScaledMMXLA requires channelwise weight scales" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # [out, in] (different than cutlass_scaled_mm) + weight = getattr(layer, self.w_q_name) + replace_parameter(layer, self.w_q_name, + torch.nn.Parameter(weight.data, requires_grad=False)) + + # WEIGHT SCALE + # XLA kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + + # [out_channel,] (different than cutlass_scaled_mm) + weight_scale = weight_scale.squeeze(-1) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # Only support symmetric dynamic activation quantization. + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + # Filter warning for cond usage in apply_weights. It is okay + # to specialize the graph since bias is not dynamic. + warnings.filterwarnings( + "ignore", + message= + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501 + ) + + def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + + def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + bias + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + out = torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + zero_point=None, + block_size=-1, + int4_weight=False, + quantize_activation=True) + # `quantized_matmul` output is fp32, cast it down to bf16 for perf + out = out.to(x.dtype) + # Explicitly capture control flow to make dynamo happy. + # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 + return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py new file mode 100644 index 0000000..5d766c2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class BaseKVCacheMethod(QuantizeMethodBase): + """ + Quant method that adds `_k_scale` and `_v_scale` attributes to the + Attention layer to support loading those scaling factors from checkpoints. 
+ The k/v_scale will be used to: + - quantize k/v_cache entries before saving them to the cache + - dequantize k/v_cache entries before fetching them from the cache + + :param quant_config: the appropriate QuantizationConfig + """ + + def __init__(self, quant_config: QuantizationConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """ + Create "weight" (aka q_scale, k_scale and v_scale) + for an attention layer. + """ + # Initialize the Q and KV cache scales to -1.0, an invalid value. + # If the q and k/v_scales appear in the checkpoint, it will be + # overwritten when loading weights. + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError( + f"{self.__class__.__name__}.apply should not be called.") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + # No need to process kv scales after loading if we are going to + # calculate them on the fly. + if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + if layer.q_scale < 0.0: + logger.warning_once( + "Checkpoint does not provide a q scaling factor. " + "Setting it to k_scale. This only matters for " + "the flash-attn backend.") + layer._q_scale.copy_(k_scale) + + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if (k_scale == 1.0 and v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + logger.warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. 
Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + del layer.k_scale + del layer.v_scale diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py new file mode 100644 index 0000000..4cf0c67 --- /dev/null +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) + +logger = init_logger(__name__) + + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. + + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + lm_head_quantized: bool, + ) -> None: + # Group size for the quantization. + self.group_size = group_size + self.lm_head_quantized = lm_head_quantized + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin, but got group_size of " + f"{self.group_size}") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // 4 + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return (f"MarlinConfig(group_size={self.group_size}, " + f"lm_head_quantized={self.lm_head_quantized})") + + @classmethod + def get_name(cls) -> str: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(group_size, lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" + or hf_quant_cfg.get("is_marlin_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "marlin") + + if is_marlin_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". 
+ format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["MarlinLinearMethod"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return MarlinLinearMethod(self) + return None + + +class MarlinLinearMethod(LinearMethodBase): + """Linear method for Marlin. + + Args: + quant_config: The Marlin quantization config. + """ + + def __init__(self, quant_config: MarlinConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. 
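+        # The packed tensor below has shape
+        # [in_features // tile_size, out_features * tile_size // pack_factor];
+        # e.g. (illustrative numbers) in_features=4096, out_features=4096,
+        # tile_size=16, pack_factor=8 gives a [256, 8192] int32 tensor.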
+ qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B", qweight) + layer.register_parameter("s", scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s = Parameter(layer.s.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py new file mode 100644 index 0000000..3de1536 --- /dev/null +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Union + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm._custom_ops import (cutlass_scaled_fp4_mm, + cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, requantize_with_max_scale) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +QUANT_ALGOS = ["FP8", "NVFP4"] 
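+# Quantization algorithms recognized in a checkpoint's hf_quant_config.json:
+# from_config() rejects any "quant_algo" not listed in QUANT_ALGOS;
+# KV_CACHE_QUANT_ALGOS lists the supported kv-cache quantization algorithms.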
+KV_CACHE_QUANT_ALGOS = ["FP8"] + + +class ModelOptFp8Config(QuantizationConfig): + """Config class for ModelOpt FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + ) -> None: + super().__init__() + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" + " the format is experimental and could change.") + + @classmethod + def get_name(cls) -> str: + return "modelopt" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + if quant_method not in QUANT_ALGOS: + raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" + " quantizations in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + is_checkpoint_fp8_serialized = ("FP8" in quant_method) + + return cls(is_checkpoint_fp8_serialized) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + return ModelOptFp8LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + return None + + +class ModelOptFp8LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer static quantization. + Supports loading FP8 checkpoints with static weight scale and + activation scale. Future support might be added for dynamic + scales. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn datatype + Args: quant_config: The ModelOpt quantization config. 
+ """ + + def __init__(self, quant_config: ModelOptFp8Config): + self.quant_config = quant_config + self.fp8_linear = Fp8LinearOp() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + # INPUT SCALE + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + + def process_weights_after_loading(self, layer: Module) -> None: + weight = layer.weight + max_w_scale = layer.weight_scale.max() + if not (layer.weight_scale == layer.weight_scale[0]).all(): + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias) + + +class ModelOptNvFp4Config(QuantizationConfig): + """Config class for ModelOpt FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool, + kv_cache_quant_algo: str, + exclude_modules: List[str], + group_size: int = 16, + ) -> None: + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning( + "Detected ModelOpt NVFP4 checkpoint. 
Please note that" + " the format is experimental and could change in future.") + + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + @classmethod + def get_name(cls) -> str: + return "modelopt_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half, torch.float8_e4m3fn] + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + if quant_method not in QUANT_ALGOS: + raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" + " quantizations in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) + kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] + group_size = quant_config["group_size"] + exclude_modules = quant_config["exclude_modules"] + if not (group_size and kv_cache_quant_algo and exclude_modules): + raise ValueError("NVFP4 quantization requires group size and " + "kv_cache_quant_algo specified in " + "hf_quant_config.json") + return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, + exclude_modules, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.exclude_modules): + return UnquantizedLinearMethod() + return ModelOptNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + return None + + +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + +class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Union[ModelOptFp8Config, + ModelOptNvFp4Config]): + super().__init__(quant_config) + + +class ModelOptNvFp4LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + input_scale: torch.float32, scalar , + weight: NVFP4(represented as byte) Shape: [1, X, y/2] + weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, + weight_scale_2: torch.float32, scalar, + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptNvFp4Config): + self.quant_config = quant_config + self.cutlass_nvfp4_supported = cutlass_fp4_supported() + if not self.cutlass_nvfp4_supported: + raise ValueError("Current platform does not support NVFP4" + " quantization. 
Please use Blackwell and above.") + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if (input_size_per_partition % 16 != 0): + raise ValueError("Unsupported model when in features size is " + "not multiple of 16") + # The nvfp4 weight is still represented as + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype) + # Weight + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 items are packed in the input dimension + layer.output_size_per_partition, + layer.input_size_per_partition // 2, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # Input Weight Scale + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + # Global Weight Scale + weight_scale_2 = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_scale_2", weight_scale_2) + + # Per Block Weight Scale + weight_scale = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: Module) -> None: + + # global scales: + input_scale_2 = layer.input_scale.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + + layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, + requires_grad=False) + + # Swizzle the weight blockscale. 
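+        # swizzle_blockscale() pads the block scales to 128x4 tiles and
+        # interleaves them into the layout consumed by cutlass_scaled_fp4_mm
+        # in apply().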
+ # contracting dimension is input dimension + # block_size = 16; + assert (layer.weight_scale.shape[1] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Block scale must be represented as FP8-E4M3") + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + output_dtype = x.dtype + + # for input only the contracting dimension has a constraint. + x_m, _ = x.shape + w_n, _ = layer.weight.shape + output_shape = [x_m, w_n] + + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + s_quant = 1 / layer.input_scale + x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) + + # validate dtypes of quantized input, input block scale, + # weight and weight_blockscale + assert (x_fp4.dtype == torch.uint8) + assert (layer.weight.dtype == torch.uint8) + assert (x_blockscale.dtype == torch.float8_e4m3fn) + assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) + assert (layer.alpha.dtype == torch.float32) + + out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, + layer.weight_scale_swizzled, layer.alpha, + output_dtype) + if bias is not None: + out = out + bias + return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py new file mode 100644 index 0000000..00c4b66 --- /dev/null +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -0,0 +1,447 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional + +import torch + +from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supports_layer) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + + +class MoeWNA16Config(QuantizationConfig): + """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" + + def __init__(self, linear_quant_method: str, weight_bits: int, + group_size: int, has_zp: bool, lm_head_quantized: bool, + modules_to_not_convert: Optional[List[str]], + full_config: Dict[str, Any]) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.has_zp = has_zp + self.bit8_pack_factor = 8 // self.weight_bits + self.lm_head_quantized = lm_head_quantized + self.linear_quant_method = linear_quant_method + self.full_config = full_config + self.use_marlin = False + # Avoid circular import + from vllm.model_executor.layers.quantization.awq import AWQConfig + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig) + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) + if self.linear_quant_method == "gptq": + self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible( + full_config) + elif self.linear_quant_method == "awq": + capability_tuple = current_platform.get_device_capability() + 
device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + awq_min_capability = AWQConfig.get_min_capability() + if device_capability < awq_min_capability: + raise ValueError( + "The quantization method moe_wna16 + awq is not supported " + "for the current GPU. " + f"Minimum capability: {awq_min_capability}. " + f"Current capability: {device_capability}.") + self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible( + full_config) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + if modules_to_not_convert is None: + self.modules_to_not_convert = [] + else: + self.modules_to_not_convert = modules_to_not_convert + + @classmethod + def get_name(cls) -> str: + return "moe_wna16" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": + linear_quant_method = cls.get_from_keys(config, ["quant_method"]) + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + if linear_quant_method == "gptq": + has_zp = not cls.get_from_keys(config, ["sym"]) + modules_to_not_convert = [] + elif linear_quant_method == "awq": + has_zp = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + return cls(linear_quant_method, weight_bits, group_size, has_zp, + lm_head_quantized, modules_to_not_convert, config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) + if can_convert and user_quant == "moe_wna16": + return cls.get_name() + return None + + @classmethod + def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. 
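+        # A config is convertible when it is either GPTQ without activation
+        # reordering (desc_act) at 4 or 8 bits, or AWQ at 4 bits on a GPU that
+        # meets AWQ's minimum compute capability (checked below).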
+ quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + desc_act = quant_config.get("desc_act") + + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + # Avoid circular import + from vllm.model_executor.layers.quantization.awq import AWQConfig + awq_min_capability = AWQConfig.get_min_capability() + + gptq_compatible = quant_method == "gptq" and \ + not desc_act and num_bits in [4, 8] + awq_compatible = quant_method == "awq" and num_bits == 4 and \ + device_capability >= awq_min_capability + + return gptq_compatible or awq_compatible + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if is_layer_skipped_quant(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + elif isinstance(layer, LinearBase): + # Avoid circular import + from vllm.model_executor.layers.quantization.awq import AWQConfig + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig) + from vllm.model_executor.layers.quantization.gptq import GPTQConfig + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) + if self.linear_quant_method == "gptq": + if self.use_marlin: + return GPTQMarlinConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + else: + return GPTQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + elif self.linear_quant_method == "awq": + if self.use_marlin and check_marlin_supports_layer( + layer, self.group_size): + return AWQMarlinConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + else: + return AWQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + elif isinstance(layer, FusedMoE): + return MoeWNA16Method(self) + return None + + +def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): + return any(module_name in prefix for module_name in modules_to_not_convert) + + +class MoeWNA16Method(FusedMoEMethodBase): + """Linear method for MOE WNA16 (W8A16/W4A16) quantization. + + Args: + quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. 
+ """ + + def __init__(self, quant_config: MoeWNA16Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.quant_config = self.quant_config + bit8_pack_factor = self.quant_config.bit8_pack_factor + group_size = self.quant_config.group_size + group_size_div_factor = 1 + + # make intermediate_size and hidden_size diviable by group_size + # we reduce the group size to ensure that + # and we would repeat the loaded_weight later + while intermediate_size_per_partition % group_size or \ + hidden_size % group_size: + group_size = group_size // 2 + group_size_div_factor *= 2 + assert group_size >= 32 + layer.group_size = group_size + layer.group_size_div_factor = group_size_div_factor + + strategy = FusedMoeWeightScaleSupported.GROUP.value + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": False + }) + + assert 'weight_loader' in extra_weight_attrs + weight_loader = extra_weight_attrs['weight_loader'] + wrapped_weight_loader = MoeWNA16Method.get_weight_loader( + layer, weight_loader) + extra_weight_attrs['weight_loader'] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // bit8_pack_factor, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // bit8_pack_factor, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + w13_scales = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // group_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // group_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + if self.quant_config.has_zp: + w13_qzeros = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition // bit8_pack_factor, + hidden_size // group_size, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size // bit8_pack_factor, + intermediate_size_per_partition // group_size, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + if self.quant_config.linear_quant_method == "gptq": + # some param are unused, but we need to init them in order to + # load weights + invalid_param_keys = ["w13_g_idx", "w2_g_idx"] + if not self.quant_config.has_zp: + invalid_param_keys += ["w13_qzeros", "w2_qzeros"] + for key in invalid_param_keys: + param = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int32), + requires_grad=False) + layer.register_parameter(key, param) + set_weight_attrs(param, 
extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + assert activation == "silu", "Only SiLU activation is supported." + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + + return fused_experts( + x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size]) + + @staticmethod + def get_weight_loader(layer, weight_loader): + + def convert_awq_tensor(tensor, tensor_type): + # convert awq qweight/qzeros to a standard format (assume int4) + # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) + # qzeros: (k // group_size, n // pack_factor_bit32) -> + # (n // pack_factor_bit8, k // group_size) + # pack_factor_bit32 = 32 // weight_bits + # pack_factor_bit8 = 8 // weight_bits + + # 0. suppose origin shape (a, b), dtype int32 + # 1. convert to uint8, shape (a, b) -> (a, 4 * b) + size0 = tensor.size(0) + tensor = tensor.view(torch.uint8) + + # 2. unpack to uint4 (only when weight_bits == 4) + # shape (a, 4 * b) -> (a, 4 * b, 2) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + + # 3. change order, see + # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py + # shape -> (a, 4 * b * pack_factor_bit8) + reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7] + tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order] + tensor = tensor.view(size0, -1) + + # 4. transpose, shape -> (4 * b * pack_factor_bit8, a) + tensor = tensor.T.contiguous() + + # 5. 
repack (only when weight_bits == 4) + # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8) + # qzeros shape -> (4 * b, a) + + if tensor_type == "qweight": + tensor = tensor[:, 1::2] * 16 + tensor[:, ::2] + elif tensor_type == "qzeros": + tensor = tensor[1::2, :] * 16 + tensor[::2, :] + return tensor + + def convert_gptq_int4_qzeros(tensor): + tensor = tensor.view(torch.uint8) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + tensor = tensor + 1 + tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 + return tensor + + def moe_wna16_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: str, + expert_id: int): + if "g_idx" in weight_name: + return + if not layer.quant_config.has_zp and "qzeros" in weight_name: + return + + device = get_tp_group().device + tp_rank = get_tensor_model_parallel_rank() + loaded_weight = loaded_weight.to(device) + shard_size = layer.intermediate_size_per_partition + + # convert gptq and awq weight to a standard format + if layer.quant_config.linear_quant_method == "awq": + assert layer.quant_config.weight_bits == 4 + if "weight" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, + "qweight") + elif "zeros" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") + else: + loaded_weight = loaded_weight.T + elif layer.quant_config.linear_quant_method == "gptq": + assert layer.quant_config.weight_bits in [4, 8] + if "weight" in weight_name: + loaded_weight = loaded_weight.T.contiguous().view( + torch.uint8) + elif "zeros" in weight_name: + # add 1 to gptq qzeros to align with awq + loaded_weight = loaded_weight.view(torch.uint8) + if layer.quant_config.weight_bits == 4: + loaded_weight = convert_gptq_int4_qzeros( + loaded_weight).T + else: + loaded_weight = loaded_weight.T + 1 + else: + loaded_weight = loaded_weight.T + + # repeat the qzeros/scales to fit new group size + if layer.group_size_div_factor > 1 and \ + "qzeros" in weight_name or "scales" in weight_name: + loaded_weight = loaded_weight.repeat_interleave( + layer.group_size_div_factor, 1) + + if "w13_qzeros" in weight_name: + tensor = loaded_weight.view(layer.tp_size, -1, + loaded_weight.size(1))[tp_rank] + if shard_id == "w1": + param.data[expert_id, :shard_size // 2] = tensor + else: + param.data[expert_id, shard_size // 2:] = tensor + elif "w2_qzeros" in weight_name: + param.data[expert_id] = loaded_weight.view( + loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + else: + weight_loader(param, loaded_weight, weight_name, shard_id, + expert_id) + + return moe_wna16_weight_loader diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py new file mode 100644 index 0000000..f6f6680 --- /dev/null +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from importlib.util import find_spec +from typing import Any, Dict, List, Optional + +from torch.nn import Module + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] + + +class NeuronQuantConfig(QuantizationConfig): + """Int8 Quantization Config class for Neuron Backend.""" + + def __init__( + self, + dequant_dtype: str = "f16", + quantize_method: str = "vector_dynamic", + ) -> None: + super().__init__() + self.quant_dtype = 
os.getenv("NEURON_QUANT_DTYPE", "s8") + if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: + raise ValueError( + f"Neuron quantization datatype {self.quant_dtype} is not valid," + f" the quantization datatype should match one of the below " + f"types {SUPPORTED_QUANT_DTYPE_LIST}") + self.dequant_dtype = dequant_dtype + self.quantize_method = quantize_method + + def get_name(self) -> str: + return "neuron_quant" + + def get_supported_act_dtypes(self) -> List[str]: + return SUPPORTED_QUANT_DTYPE_LIST + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with Neuron Backend") + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig": + quantize_method = cls.get_from_keys(config, ["quantize_method"]) + dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) + return cls(dequant_dtype=dequant_dtype, + quantize_method=quantize_method) + + def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: + if find_spec("transformers_neuronx") is not None: + return self.get_quantization_config() + else: + raise NotImplementedError( + "Neuron Quantization is only supported through" + " transformers_neuronx.") + + def get_quantization_config(self): + from transformers_neuronx.config import QuantizationConfig + return QuantizationConfig(quant_dtype=self.quant_dtype, + dequant_dtype=self.dequant_dtype, + quantize_method=self.quantize_method) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py new file mode 100644 index 0000000..592ffc5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + Fp8KVCacheMethod, + Fp8LinearMethod) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) +from vllm.platforms import current_platform + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class PTPCFp8Config(Fp8Config): + """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8.""" + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + ) -> None: + if not current_platform.is_rocm(): + raise ValueError( + "ptpc_fp8 quantization is supported only on ROCm.") + + if not current_platform.has_device_capability(94): + raise ValueError( + "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." 
# noqa: E501 + ) + if activation_scheme == "static": + raise ValueError( + "ptpc_fp8 as of now only support dynamic quantization.") + + super().__init__(is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + @classmethod + def get_name(cls) -> str: + return "ptpc_fp8" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls(activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return PTPCFp8LinearMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + +class PTPCFp8LinearMethod(Fp8LinearMethod): + """Linear method for Per-Token and Per-Channel FP8 Quantization. + Only supports loading quantized BF16 model checkpoints with dynamic + activation scaling. To load FP16 model checkpoints, user must specify + to convert the FP16 model weight loading into BF16. + The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support float8_e4m3fnuz data type due to the limitation of + torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: PTPCFp8Config): + super().__init__(quant_config=quant_config) + # Force weight quantization + self.quant_config.is_checkpoint_fp8_serialized = False + self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False, + use_per_token_if_dynamic=True) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + + assert layer.weight.data.dtype == torch.bfloat16, \ + f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + # Quantize the weights. + qweight, weight_scale = ops.scaled_fp8_quant( + layer.weight, scale=None, use_per_token_if_dynamic=True) + + # Update the layer with the new values. 
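+        # Shape sketch (for illustration): for a weight of shape
+        # (out_features, in_features), dynamic per-token quantization here
+        # yields an FP8 tensor of the same shape plus one scale per output
+        # row, i.e. weight_scale of shape (out_features, 1); the transpose
+        # below lays the weight out for the rowwise scaled_mm path.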
+ layer.weight = Parameter( + qweight.t(), requires_grad=False) # Pretranspose the weight + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=None, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py new file mode 100644 index 0000000..1e05917 --- /dev/null +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) + +logger = init_logger(__name__) + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + + +class QQQConfig(QuantizationConfig): + """Config class for QQQ + + Reference: https://arxiv.org/pdf/2406.09904 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + is_sym: bool = True, + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.is_sym = is_sym + + # Verify + if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: + raise ValueError( + f"QQQ does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"QQQ does not support group_size = {self.group_size}. " + f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " + "are supported.") + if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: + raise ValueError( + f"QQQ does not support is_sym = {self.is_sym}. " + f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by QQQ kernels. + self.tile_size = MARLIN_QQQ_TILE + + # Min out_features dim + self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = MARLIN_QQQ_MAX_PARALLEL + + # Permutation length used by the QQQ kernels. 
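+        # For the supported 4-bit case this gives pack_factor = 32 // 4 = 8
+        # int4 values per int32, and with tile_size = 16 a perm_len of 1024
+        # corresponds to 1024 // 16**2 = 4 tiles per permutation group
+        # (the divisibility constraint checked later in create_weights).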
+ self.perm_len = 1024 + + def __repr__(self) -> str: + return "QQQConfig(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "qqq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QQQConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QQQLinearMethod"]: + if isinstance(layer, LinearBase): + return QQQLinearMethod(self) + return None + + +class QQQLinearMethod(LinearMethodBase): + """Linear method for QQQ. + + Args: + quant_config: The QQQ quantization config. + """ + + def __init__(self, quant_config: QQQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs["weight_loader"] + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. 
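+        # Shape sketch (hypothetical sizes): with weight_bits=4
+        # (pack_factor=8), tile_size=16, input_size_per_partition=4096 and
+        # output_size_per_partition=4096, the packed tensor below holds
+        # (4096 // 16, 4096 * 16 // 8) = (256, 8192) int32 elements.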
+ qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + s_channel = ChannelQuantScaleParameter(data=torch.empty( + 1, + output_size_per_partition, + device="cuda", + dtype=torch.float, + ), + weight_loader=weight_loader, + output_dim=1) + + if self.quant_config.group_size == -1: + s_group_data = torch.tensor( + [], + device="cuda", + dtype=torch.half, + ) + else: + s_group_data = torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + device="cuda", + dtype=torch.half, + ) + + s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} + + if self.quant_config.group_size == -1: + s_group = BasevLLMParameter(**s_group_attr) + else: + s_group = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **s_group_attr) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B", qweight) + layer.register_parameter("s_channel", s_channel) + layer.register_parameter("s_group", s_group) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) + layer.s_group = Parameter(layer.s_group.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + s_ch = layer.s_channel + s_group = layer.s_group + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = s_ch.shape[1] + + x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) + + output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, + workspace, size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/quark/__init__.py b/vllm/model_executor/layers/quantization/quark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py new file mode 100644 index 0000000..ca71da8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -0,0 +1,390 @@ +# SPDX-License-Identifier: Apache-2.0 + +import fnmatch +import re +from typing import Any, Dict, List, Optional, cast + +import torch + +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig, 
QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 + QuarkMoEMethod) +from vllm.model_executor.layers.quantization.quark.schemes import ( + QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) +from vllm.model_executor.layers.quantization.quark.utils import ( + deep_compare, should_ignore_layer) +from vllm.platforms import current_platform + +__all__ = ["QuarkLinearMethod"] + + +class QuarkConfig(QuantizationConfig): + + def __init__(self, + quant_config: Dict[str, Any], + kv_cache_group: Optional[List[str]] = None, + kv_cache_config: Optional[Dict[str, Any]] = None, + pack_method: str = "reorder"): + super().__init__() + if kv_cache_group is None: + kv_cache_group = [] + self.quant_config = quant_config + self.kv_cache_group = kv_cache_group + self.kv_cache_config = kv_cache_config + self.pack_method = pack_method + + def get_linear_method(self) -> "QuarkLinearMethod": + return QuarkLinearMethod(self) + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> str: + return "quark" + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + # Check if the layer is skipped for quantization. + exclude_layers = cast(List[str], self.quant_config.get("exclude")) + if should_ignore_layer(prefix, + ignore=exclude_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return QuarkLinearMethod(self) + if isinstance(layer, Attention): + return QuarkKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return QuarkMoEMethod.get_moe_method(self, + module=layer, + layer_name=prefix) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": + export_config = config.get("export") + if export_config is None: + raise ValueError("The export key should be included in " + "the configurations of Quark quantized model") + kv_cache_group = cast(List[str], export_config.get("kv_cache_group")) + pack_method = cast(str, export_config.get("pack_method")) + + # In the export model of quark, the quantization configuration + # of kv_cache is stored in layer_quant_config. First, it is + # judged whether kv_cache_group exists, and then it is judged + # whether layer_quant_config has a quantization configuration + # that matches kv_cache. 
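+        # Hypothetical example of the relevant fields (values illustrative):
+        #   "export": {"kv_cache_group": ["*k_proj", "*v_proj"],
+        #              "pack_method": "reorder"},
+        #   "layer_quant_config": {"*k_proj": {...}, "*v_proj": {...}}
+        # Every kv_cache_group entry must have a matching key in
+        # layer_quant_config, which is what the branch below verifies.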
+ if len(kv_cache_group) == 0: + kv_cache_config = None + else: + kv_cache_set = set(kv_cache_group) + layer_quant_config = cast(Dict[str, Any], + config.get("layer_quant_config")) + layer_quant_names = list(layer_quant_config.keys()) + layer_quant_set = set(layer_quant_names) + + if not kv_cache_set.issubset(layer_quant_set): + raise ValueError("The Quark quantized model has the " + "kv_cache_group parameter setting, " + "but no kv_cache quantization settings " + "were found in the quantization " + "configuration.") + + q_configs = [ + cast(Dict[str, Any], layer_quant_config.get(name)) + for name in kv_cache_group + ] + if not all( + deep_compare(q_config, q_configs[0]) + for q_config in q_configs): + raise ValueError( + "The quantization method used for kv_cache should " + "be the same, but the quantization method for the " + "kv_cache layer in the config is different.") + kv_cache_config = q_configs[0].get("output_tensors") + if kv_cache_config is None: + raise ValueError( + "The kv_cache quantization configuration is empty.") + + # Since we have already set kv_cache quantization configurations, + # we will remove the quantization configuration for the + # output_tensors corresponding to the kv_cache layer. + for q_config in q_configs: + q_config["output_tensors"] = None + + return cls(quant_config=config, + kv_cache_group=kv_cache_group, + kv_cache_config=kv_cache_config, + pack_method=pack_method) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _check_scheme_supported(self, + min_capability: int, + error: bool = True) -> bool: + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") + return supported + else: + return False + + def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], + input_quant: Optional[Dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm weight scheme is supported + is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3" + and input_quant.get("dtype") == "fp8_e4m3") + is_static_weight = not weight_quant.get("is_dynamic") + is_per_tensor_or_channel_weight = (weight_quant.get("qscheme") + in ["per_tensor", "per_channel"]) + + if not (is_fp8_dtype and is_static_weight + and is_per_tensor_or_channel_weight): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.get("is_dynamic"): + return True + + # Confirm activation scheme is supported. + is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor") + return is_per_tensor_activation + + def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], + input_quant: Optional[Dict[str, Any]]) -> bool: + # Confirm weights and input quantized. 
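+        # For illustration, a config pair accepted here would look roughly
+        # like:
+        #   weight_quant = {"dtype": "int8", "qscheme": "per_channel",
+        #                   "is_dynamic": False, "symmetric": True}
+        #   input_quant  = {"dtype": "int8", "qscheme": "per_tensor",
+        #                   "is_dynamic": False}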
+ if weight_quant is None or input_quant is None: + return False + + is_int8_dtype = (weight_quant.get("dtype") == "int8" + and input_quant.get("dtype") == "int8") + + is_tensor = (weight_quant.get("qscheme") + in ["per_tensor", "per_channel"] + and input_quant.get("qscheme") == "per_tensor") + + is_static = (not weight_quant.get("is_dynamic") + and not input_quant.get("is_dynamic")) + + is_weight_symmetric = (weight_quant.get("symmetric") is True) + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + + def _find_matched_config(self, layer_name: str, + module: torch.nn.Module) -> Dict[str, Any]: + + proj_name = layer_name.split(".")[-1] + if proj_name in self.packed_modules_mapping: + shard_proj_names = self.packed_modules_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + shard_configs = [ + self._find_matched_config(shard_name, module) + for shard_name in shard_names + ] + if not all( + deep_compare(q_config, shard_configs[0]) + for q_config in shard_configs): + raise ValueError( + f"Found a different quantization configuration for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + return shard_configs[0] + else: + layer_quant_config = cast( + Dict[str, Any], self.quant_config.get("layer_quant_config")) + for name_pattern in layer_quant_config: + if fnmatch.fnmatch(layer_name, name_pattern): + return layer_quant_config[name_pattern] + + layer_type = cast(str, type(module)) + layer_type_quant_config = cast( + Dict[str, Any], + self.quant_config.get("layer_type_quant_config")) + if layer_type in layer_type_quant_config: + return layer_type_quant_config[layer_type] + + global_quant_config = cast( + Dict[str, Any], self.quant_config.get("global_quant_config")) + return global_quant_config + + def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme": + if config.get("output_tensors") or config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with output_tensors " + "and bias quantized are not supported") + weight_config = cast(Dict[str, Any], config.get("weight")) + input_config = cast(Dict[str, Any], config.get("input_tensors")) + + if self._is_fp8_w8a8(weight_config, input_config): + is_fp8_w8a8_supported = self._check_scheme_supported( + QuarkW8A8Fp8.get_min_capability(), error=False) + if is_fp8_w8a8_supported: + weight_qscheme = cast(str, weight_config.get("qscheme")) + input_static = (input_config is not None and + not cast(bool, input_config.get("is_dynamic"))) + return QuarkW8A8Fp8(qscheme=weight_qscheme, + is_static_input_scheme=input_static) + elif self._is_static_tensor_w8a8(weight_config, input_config): + weight_qscheme = cast(str, weight_config.get("qscheme")) + return QuarkW8A8Int8(qscheme=weight_qscheme, + is_static_input_scheme=True, + input_symmetric=input_config.get("symmetric")) + + raise NotImplementedError("No quark compatible scheme was found. " + f"Weight config: {weight_config}, " + f"Input config: {input_config}") + + def get_scheme(self, layer: torch.nn.Module, + layer_name: str) -> "QuarkScheme": + + layer_quant_config = self._find_matched_config(layer_name, layer) + + # Find the quant_scheme + scheme = self._get_scheme_from_config(layer_quant_config) + # Raise error if device does not support the scheme + # (e.g. 
fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + + return scheme + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in quark. If this is the case, return its equivalent param name + expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if self.kv_cache_group is None or len(self.kv_cache_group) == 0: + return None + + kv_proj_names = [ + re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group + ] + if name.endswith(".output_scale"): + if len(kv_proj_names) == 1 and kv_proj_names[0] in name: + kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale" + return name.replace(kv_output_scale_name, ".attn.k_scale") + + elif len(kv_proj_names) == 2: + for kv_proj_name in kv_proj_names: + if kv_proj_name in name and kv_proj_name == "k_proj": + return name.replace(".k_proj.output_scale", + ".attn.k_scale") + elif kv_proj_name in name and kv_proj_name == "v_proj": + return name.replace(".v_proj.output_scale", + ".attn.v_scale") + + # If no matches, return None + return None + + +class QuarkLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: QuarkConfig): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + + +class QuarkKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from quark checkpoints. + """ + + def __init__(self, quant_config: QuarkConfig): + self.validate_kv_cache_config(quant_config.kv_cache_config) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]): + """ + Validator for the kv cache configuration. 
Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_config: the quark kv cache scheme + """ + if kv_cache_config is None: + return + + dtype = kv_cache_config.get("dtype") + if dtype != "fp8_e4m3": + raise NotImplementedError( + "Currently supported kv cache quantization is " + f"dtype=fp8_e4m3, however received {dtype}") + + qscheme = kv_cache_config.get("qscheme") + if qscheme != "per_tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for quark KV cache. " + f"Expected qscheme: per_tensor, found qscheme: {qscheme}") diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py new file mode 100644 index 0000000..d1146c0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Optional + +import torch + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"] + + +class QuarkMoEMethod(FusedMoEMethodBase): + + @staticmethod + def get_moe_method( + quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + module: torch.nn.Module, + layer_name: str) -> "QuarkMoEMethod": + layer_quant_config = quant_config._find_matched_config( + layer_name, module) + + if (layer_quant_config.get("output_tensors") + or layer_quant_config.get("bias")): + raise NotImplementedError("Currently, Quark models with " + "output_tensors and bias " + "quantized are not supported") + weight_config = layer_quant_config.get("weight") + input_config = layer_quant_config.get("input_tensors") + + if quant_config._is_fp8_w8a8(weight_config, input_config): + return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + else: + raise RuntimeError("Unsupported FusedMoe scheme") + + +class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): + + def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str, + Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_tensor" + and input_qscheme == "per_tensor"): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales " + "for weights and activations are supported. 
Found " + f"{weight_qscheme}, {input_qscheme}") # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. 
") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + if current_platform.is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py new file mode 100644 index 0000000..9069b5a --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .quark_scheme import QuarkScheme +from .quark_w8a8_fp8 import QuarkW8A8Fp8 +from .quark_w8a8_int8 import QuarkW8A8Int8 + +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py new file mode 100644 index 0000000..40c8ea8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["QuarkScheme"] + + +class QuarkScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by Quark. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py new file mode 100644 index 0000000..c161849 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional + +import torch +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +__all__ = ["QuarkW8A8Fp8"] + + +class QuarkW8A8Fp8(QuarkScheme): + + def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): + self.qscheme = qscheme + self.is_static_input_scheme = is_static_input_scheme + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.out_dtype = torch.get_default_dtype() + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer) -> None: + # If per tensor, when we have a fused module (e.g. 
QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.qscheme == "per_tensor": + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + if current_platform.is_fp8_fnuz(): + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. + elif self.qscheme == "per_channel": + weight = layer.weight + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + + layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + else: + raise ValueError(f"Unknown quantization scheme {self.qscheme}") + + # INPUT SCALE + if self.is_static_input_scheme: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.qscheme == "per_channel": + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.qscheme == "per_tensor" + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py 
b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py new file mode 100644 index 0000000..1bf34b0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional, Set + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + + +class QuarkW8A8Int8(QuarkScheme): + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool], + input_symmetric: Optional[bool]): + self.qscheme = qscheme + self.is_static_input_scheme = is_static_input_scheme + self.input_symmetric = input_symmetric + + @classmethod + def get_min_capability(cls) -> int: + # turing and up + return 75 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + self.logical_widths = output_partition_sizes + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.qscheme == "per_channel"), + is_static_input_scheme=(self.is_static_input_scheme is True), + input_symmetric=(self.input_symmetric is True)) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.qscheme == "per_channel": + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.qscheme == "per_tensor" + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + if not self.input_symmetric: + # Note: quark stores the zp using the same dtype + # as the weights + # AZP loaded as int8 but used as int32 + input_zero_point = BasevLLMParameter( + data=torch.empty(1, dtype=torch.int8), + weight_loader=weight_loader) + layer.register_parameter("input_zero_point", input_zero_point) + + self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj") + + # Checkpoints are serialized in quark format, which is + # different from the format the kernel may want. Handle repacking here. 
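+    # (Typically this involves reordering or transposing the int8 weight for
+    # the selected scaled_mm backend and, when activations are asymmetric,
+    # precomputing the azp_adj correction from the weights; the exact steps
+    # depend on the kernel chosen in create_weights.)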
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py new file mode 100644 index 0000000..17e0df0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from types import MappingProxyType +from typing import Any, Iterable, List, Mapping, Optional + + +def deep_compare(dict1: Any, dict2: Any) -> bool: + if type(dict1) is not type(dict2): + return False + if isinstance(dict1, dict): + if dict1.keys() != dict2.keys(): + return False + return all(deep_compare(dict1[k], dict2[k]) for k in dict1) + elif isinstance(dict1, list): + return set(dict1) == set(dict2) + else: + return dict1 == dict2 + + +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_proj_names = fused_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError(f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, + targets=ignore) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, + targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def _is_equal_or_regex_match(value: str, + target: str, + check_contains: bool = False) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. 
+ """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py new file mode 100644 index 0000000..026881f --- /dev/null +++ b/vllm/model_executor/layers/quantization/schema.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains the Pydantic schemas for various quantization-related +parameters. When a relevant quantization technique is specified, these +parameters are loaded in the form of a JSON alongside the model weights +and augment the model with additional information needed for use of that +technique. The format of this JSON should be specified by one or more +schemas contained here. + +For example, when the KV cache is quantized to FP8-E4M3 (currently only +possible on ROCm), the model can be optionally augmented with KV cache +scaling factors. +""" + +from typing import Dict, Optional + +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator + + +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!") + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}.") + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}.") + for i in range(tp_size): + assert i in self.scaling_factor, ( + f"KV cache scales map for TP rank {i} not found.") + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}.") + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. 
weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!") + return self diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py new file mode 100644 index 0000000..14e5bcf --- /dev/null +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import ModelWeightParameter + +ACTIVATION_SCHEMES = ["none"] + + +class Int8TpuConfig(QuantizationConfig): + """Int8 Quantization Config class for TPU Backend.""" + + def __init__( + self, + activation_scheme: str = "none", + ) -> None: + super().__init__() + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + def get_name(self) -> str: + return "tpu_int8" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with TPU Backend") + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(activation_scheme=activation_scheme) + + def get_quant_method(self, layer: Module, + prefix: str) -> Optional["TPUInt8LinearMethod"]: + if isinstance(layer, LinearBase): + return TPUInt8LinearMethod(self) + return None + + +class TPUInt8LinearMethod(LinearMethodBase): + """Int8 Linear method for TPU Quant. 
""" + + def __init__(self, quant_config: Int8TpuConfig): + self.quant_config = quant_config + + def create_weights(self, layer: Module, input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + weight_loader = extra_weight_attrs.get("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + def _quantize_weight( + self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + weight_dtype = weight.dtype + weight = weight.cpu().to(torch.float32) + n_bit = 8 + eps = 1e-5 + max_int = 2**(n_bit - 1) - 1 + min_int = -(2**(n_bit - 1)) + max_val = weight.abs().amax(dim=-1, keepdim=True) + max_val = max_val.clamp(min=eps) + qscale = max_val / max_int + qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int, + max_int).to(torch.int8) + qscale = qscale.squeeze().to(weight_dtype) + return qweight, qscale + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = Parameter(layer.weight.data, requires_grad=False) + device = layer.weight.device + qweight, qscale = self._quantize_weight(layer.weight) + qweight = qweight.to(device) + qscale = qscale.to(device) + layer.weight = Parameter(qweight, requires_grad=False) + layer.scale = Parameter(qscale, requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + try: + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + except ImportError as err: + raise ImportError( + "Please install torch_xla by following the instructions at " + "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501 + "to run vLLM on TPU.") from err + weight = layer.weight + scale = layer.scale + out = torch.ops.xla.quantized_matmul(x, weight, scale) + if bias is not None: + out = out + bias + return out diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000..f7ee472 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/vllm/model_executor/layers/quantization/utils/allspark_utils.py b/vllm/model_executor/layers/quantization/utils/allspark_utils.py new file mode 100644 index 0000000..9786076 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/allspark_utils.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024 +ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128] +ALLSPARK_AMPERE_N_ALIGN = 16 +ALLSPARK_AMPERE_K_ALIGN = 16 + + +def check_allspark_supported_dtype_shape(input_size_per_partition: int, + output_size_per_partition: int, + group_size: int, + weight_dtype: ScalarType, + act_dtype: torch.dtype): + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + # For Ampere GPU + if device_capability >= 80 and 
device_capability < 90: + if group_size != -1: + return False, \ + "For Ampere GPU, AllSpark does not support group_size "\ + f"= {group_size}. Only group_size = -1 are supported." + + if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES: + return False, "For Ampere GPU, AllSpark does not support "\ + f"quant type ({weight_dtype}). Only quant type "\ + f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported." + + if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \ + or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0: + return False, \ + "AllSpark needs input_size_per_partition % "\ + f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\ + f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\ + "for Ampere GPU optimized kernels." + + if act_dtype != torch.float16 and act_dtype != torch.bfloat16: + return False, \ + "AllSpark only supports act_dtype = float16 or bfloat16,"\ + f"for Ampere GPU, but got act_dtype = {act_dtype}." + else: + return False, "AllSpark currently does not support "\ + f"device_capability = {device_capability}." + + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e9a50e1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, 
+ "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..119969d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + 
"512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..119969d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3e8ebf3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2bb5b45 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6496a38 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6e2aeee --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b0f9442 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b3bf9ea --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 
1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7e52ab6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7e52ab6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bee8d03 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9da876d --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3618053 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0a1a252 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..46a982f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9696611 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d6279a1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..defaacb --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ecc2fda --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ecc2fda --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3bc0036 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..310dff4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..035ec02 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..206c8a2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8b49f27 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..edc2353 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..987c8f6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + 
"num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..108af31 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": 
{ + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..108af31 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..43b5bdb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bffa749 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + 
"num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..851bc9f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f96f127 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d1227c2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fe3e18c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b3ed43a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + 
"kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..abd1915 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..abd1915 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e4d5b2d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..137b9dd --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..77ba0d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1c61451 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..38cac46 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8e6ebe2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..63e661c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..459062e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1225d84 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..03e8235 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + 
"num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bb61d83 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { 
+ "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bb61d83 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, 
+ "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d44e384 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c559a69 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cf35403 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8ec2005 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..65840aa --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1a457b9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..574cf49 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..574cf49 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0a5d7bf --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4e120d6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..eccb86a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..125fe36 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4415cc9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + 
"kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bfaf93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bfaf93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cb91a27 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..88af484 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5c29874 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..dd06972 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + 
"2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..125fe36 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7c039b4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c2bd478 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c2bd478 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4990268 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..18afdd9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7febe3d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, 
+ "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56b939e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..51d10bb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1480e09 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..63d9a0b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } 
+} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f5fdec3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6bd350c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5c604b9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + 
"num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..75906ad --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 
+ }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..75906ad --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 
4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7fa398c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f15d8f6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + 
"num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4d25ae --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": 
{ + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fdc6437 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fdc6437 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9d7658b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + 
"3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cd3e078 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2b9f0d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9d5a329 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + 
"2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7f449db --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..634c1bf --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7eaa7d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7eaa7d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..03dba5a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..96e1594 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 
+ }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d979c6b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + 
"num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5ffd367 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..be93dfe --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..19452df --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3382554 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + 
"GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3382554 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 
16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9a5ff48 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6eb22de --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..eabc423 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + 
}, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..84ef35e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e6d9107 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + 
"GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c9d18c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, 
+ "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c9d18c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c746e70 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0b4746c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, 
+ "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..386928d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + 
"BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..51e237b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8ec2005 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..202acf2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6280219 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + 
"32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..983525f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,18 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..11a9bce --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c298da8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56a766c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56a766c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..386ee59 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..60df5e3 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..40c01c0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4f1747b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 
4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c6fd365 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 
64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..53bbaca --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cb993c8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f250d3f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f250d3f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ffe67dc --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, 
+ "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2a17e16 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..160f12e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b259993 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e5c4a1d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a71ab88 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56d3e1f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, 
+ "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bbd4df4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + 
"kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bbd4df4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + 
"matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..eda96e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bd0767b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2bf5eb2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..29f7651 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6db1385 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9cdff13 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bb8e87 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bb8e87 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1a47cae --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8dd5ae5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9c908e8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, 
+ "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0a1e14c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6d1a8b5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e77abaf --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..15b1c93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} 
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0cf6a47 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..01327b2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6f9bd75 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + 
"num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f050b75 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + 
"48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f050b75 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..12eea5f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 
64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9db9dae --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, 
+ "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f78e706 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + 
}, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8ff12e6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git 
a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..365f8d0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f080ea5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + 
"1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4532f93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0cf6a47 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e9bf044 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, 
+ "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c7122d3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4a3ccc0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4a3ccc0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1d3ce5c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ca7f32b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c37aced --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + 
"2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5acea24 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d962889 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3cea21b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..24ef112 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..24ef112 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3ab5796 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..58cdd93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d6bef7f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b72e037 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4b08ea --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a8141f5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + 
"num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c911a8e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { 
+ "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c911a8e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 
128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3cb7eaa --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 
+ }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8df6e4b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..293adce --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9d7edc3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c9566d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + 
}, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d86b349 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d86b349 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e471687 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4c3249 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4c3249 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py new file mode 100644 index 0000000..77f6730 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -0,0 +1,526 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +import functools +import json +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import triton +import triton.language as tl + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_dequantize) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED) +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + +logger = init_logger(__name__) + + +def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: + if isinstance(x, torch.Tensor): + x = x.dtype + try: + return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz + except: + return False + + +# TODO fix ROCm->Triton custom path: +# https://github.com/vllm-project/vllm/issues/14397 +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + shape_supported_by_cutlass = (weight.shape[0] % 128 == 0 + and weight.shape[1] % 128 == 0) + if current_platform.is_rocm(): + # TODO this is never used, as cutlass_block_fp8_supported is False + scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + + input_2d.shape[:-1])[::-1] + scale_b_shape = (weight_scale.view(-1, 1) + if weight_scale.dim() <= 1 else weight_scale.T).shape + ar, ac = scale_a_shape + br, bc = scale_b_shape + if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) + or br not in (1, weight.shape[0])): + shape_supported_by_cutlass = False + if cutlass_block_fp8_supported and shape_supported_by_cutlass: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=True) + output = ops.cutlass_scaled_mm(q_input, + weight.T, + out_dtype=input.dtype, + scale_a=x_scale, + 
scale_b=weight_scale.T) + else: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=False) + output = w8a8_block_fp8_matmul(q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=input.dtype) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def apply_w8a8_block_fp8_linear_fake( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + output_shape = [*input.shape[:-1], weight.shape[0]] + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +direct_register_custom_op( + op_name="apply_w8a8_block_fp8_linear", + op_func=apply_w8a8_block_fp8_linear, + mutates_args=[], + fake_impl=apply_w8a8_block_fp8_linear_fake, +) + + +def input_to_float8( + x: torch.Tensor, + dtype: Optional[torch.dtype] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to float8 values " + "with tensor-wise quantization.""" + dtype = current_platform.fp8_dtype() if dtype is None else dtype + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_quant_to_tensor_quant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise + quantization. The inputs are block-wise quantization tensor `x_q_block`, + block-wise quantization scale and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise + quantization scale. Note only float8 is supported for now. + """ + x_dq_block = scaled_dequantize(x_q_block, x_s) + x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + return x_q_tensor, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. 
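+    # Each program instance handles exactly one group of `group_size` elements:
+    # it computes a single scale (group absmax / fp8_max), quantizes the group,
+    # and stores `group_size` fp8 values plus one float32 scale.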
+ g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. 
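+        When `column_major_scales` is True, the returned scale tensor is
+        allocated column-major, as used by the CUTLASS block-fp8 path.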
+ """ + dtype = current_platform.fp8_dtype() if dtype is None else dtype + assert (x.shape[-1] % group_size == 0), ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size, ) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, + dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M, )]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M, )]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. 
+ """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@functools.lru_cache +def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, + block_k: int) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the w8a8 block fp8 kernel. + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + device_name = current_platform.get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501 + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block FP8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + "Using default W8A8 Block FP8 kernel config. Performance might " + "be sub-optimal! 
Config file not found at %s", + config_file_path, + ) + return None + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) + if configs: + # Get the optimal config if there is one + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] + # BLOCK_SIZE_K must be divisible by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + } + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C diff --git a/vllm/model_executor/layers/quantization/utils/gguf_utils.py b/vllm/model_executor/layers/quantization/utils/gguf_utils.py new file mode 100644 index 0000000..79b34e2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/gguf_utils.py @@ -0,0 +1,373 @@ +import torch +import numpy as np +from gguf.constants import GGMLQuantizationType + +def get_awq_format(w, group_size=128, w_bit=4): + org_w_shape = w.shape + ori_w_dtype = torch.get_default_dtype() + assert w_bit == 4 + assert w.shape[1] % group_size == 0 + + in_features = org_w_shape[1] + w = w.reshape(-1, group_size) + assert torch.isnan(w).sum() == 0 + + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**w_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + w = ( + torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros + ) * scales + zeros = zeros.view(org_w_shape[0], -1) + scales = scales.view(org_w_shape[0], -1) + w = w.reshape(org_w_shape) + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + scales = scales.t().contiguous() # input // group, o + zeros = zeros.t().contiguous() # input // group, o 
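+    # What follows re-quantizes the weights to 4-bit integer codes and packs
+    # them AWQ-style: 8 codes per int32, written in AWQ's interleaved column
+    # order (order_map), with the zero-points packed the same way.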
+ + # from auto awq + scale_zeros = zeros * scales + scales = scales.clone().to(ori_w_dtype) + + pack_num = 32 // w_bit + intweight = [] + for idx in range(in_features): + intweight.append( + torch.round( + (w[:, idx] + scale_zeros[idx // group_size]) + / scales[idx // group_size] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.to(dtype=torch.int32) + + qweight = torch.zeros( + (intweight.shape[0], intweight.shape[1] // 32 * w_bit), + dtype=torch.int32, + device=intweight.device, + ) + + for col in range(intweight.shape[1] // pack_num): + order_map = [0, 2, w_bit, 6, 1, 3, 5, 7] + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * w_bit) + + zeros = zeros.to(dtype=torch.int32, device=qweight.device) + + qzeros = torch.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * w_bit), + dtype=torch.int32, + device=zeros.device, + ) + + for col in range(zeros.shape[1] // pack_num): + order_map = [0, 2, w_bit, 6, 1, 3, 5, 7] + for i in range(pack_num): + qzero_col = zeros[:, col * pack_num + order_map[i]] + qzeros[:, col] |= qzero_col << (i * w_bit) + + return qweight, qzeros, scales + +GGML_BLOCK_SIZES = { + "F32": 4, + "F16": 2, + "Q4_0": 2 + 16, + "Q5_0": 2 + 4 + 16, + "Q8_0": 2 + 32, + "Q2_K": 256 // 16 + 256 // 4 + 2 + 2, + "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, + "Q4_K": 2 + 2 + 12 + 256 // 2, + "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, + "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2, + "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64, +} + +def dequantize_f32(data): + return np.frombuffer(data, dtype=np.float32) + +def dequantize_f16(data): + return np.frombuffer(data, dtype=np.float16) + +def dequantize_q4_0(data): + num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32) + qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:] + + return np.concatenate([ + scales * ((qs & 0xf).astype(np.int8) - 8), + scales * ((qs >> 4).astype(np.int8) - 8), + ], axis=1) + +def dequantize_q5_0(data): + num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32) + qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4] + qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:] + + bits = np.unpackbits(qh, axis=-1, bitorder="little") + + x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16 + x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16 + + return np.concatenate([ + scales * x0, + scales * x1, + ], axis=1) + +def dequantize_q8_0(data): + num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32) + qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] + return scales * qs + +def dequantize_q2_k(data): + block_size = GGML_BLOCK_SIZES["Q2_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) + d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32) + scales = data_u8[:, :16].reshape(num_blocks, 16, 1) + qs = data_u8[:, 
16:80].reshape(num_blocks, 64) + + tmp = np.stack([ + qs[:, 00:16] >> 0, + qs[:, 16:32] >> 0, + qs[:, 00:16] >> 2, + qs[:, 16:32] >> 2, + qs[:, 00:16] >> 4, + qs[:, 16:32] >> 4, + qs[:, 00:16] >> 6, + qs[:, 16:32] >> 6, + qs[:, 32:48] >> 0, + qs[:, 48:64] >> 0, + qs[:, 32:48] >> 2, + qs[:, 48:64] >> 2, + qs[:, 32:48] >> 4, + qs[:, 48:64] >> 4, + qs[:, 32:48] >> 6, + qs[:, 48:64] >> 6, + ], axis=1) + + return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) + + +def dequantize_q3_k(data): + block_size = GGML_BLOCK_SIZES["Q3_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) + bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little") + bits = 4 ^ (bits << 2) + qs = data_u8[:, 32:32 + 64].astype(np.int16) + a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2) + scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8) + scales[:, 0] = (a & 15) | ((c & 3) << 4) + scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4) + scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4) + scales[:, 3] = (b >> 4) | ((c >> 6) << 4) + scales = scales.reshape(num_blocks, 16, 1).astype(np.int16) + + return d * (scales - 32) * np.stack([ + (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]), + (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]), + (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]), + (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]), + (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]), + (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]), + (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]), + (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]), + (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]), + (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]), + (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]), + (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]), + (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]), + (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]), + (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]), + (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) + ], axis=1) + +def dequantize_q4_k(data, device=None): + block_size = GGML_BLOCK_SIZES["Q4_K"] + num_blocks = len(data) // block_size + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + # Casting to float32 because float16 is very slow on CPU + scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32) + scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32) + qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1) + qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32) + # Dequantize scales and offsets (6 bits and 4 + 2 bits) + factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1) + offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1) + # Interleave low and high quantized bits + qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32) + # Dequantize final weights using scales and offsets + weight = factors * qs2 - offsets + if device is None: + return weight + return torch.from_numpy(weight).to(device=device) + +def dequantize_q5_k(data): + block_size = GGML_BLOCK_SIZES["Q5_K"] + num_blocks = len(data) // block_size + + data_f16 = 
np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) + dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32) + scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1) + qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1) + qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32) + + bits = np.unpackbits(qh, axis=-1, bitorder="little") + + qs_hi_4 = qs >> 4 + qs_lo_4 = qs & 15 + + scales_lo_6 = scales[:, :8] & 63 + scales_hi_6 = scales[:, :8] >> 6 + scales_lo_4 = scales[:, 8:] & 15 + scales_hi_4 = scales[:, 8:] >> 4 + + m1 = dmin * scales_lo_6[:, 4] + m2 = dmin * scales_lo_6[:, 5] + m3 = dmin * scales_lo_6[:, 6] + m4 = dmin * scales_lo_6[:, 7] + m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4)) + m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4)) + m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4)) + m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4)) + + d1 = d * scales_lo_6[:, 0] + d2 = d * scales_lo_6[:, 1] + d3 = d * scales_lo_6[:, 2] + d4 = d * scales_lo_6[:, 3] + d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4)) + d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4)) + d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4)) + d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4)) + + return np.concatenate([ + d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1, + d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2, + d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3, + d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4, + d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5, + d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6, + d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7, + d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, + ], axis=1) + +def dequantize_q6_k(data, device = None): + block_size = GGML_BLOCK_SIZES["Q6_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size) + + scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32) + # TODO use uint8 and cast later? 
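+    # Q6_K block layout (210 bytes per 256 weights): 128 bytes of low 4-bit
+    # quants (ql), 64 bytes of high 2-bit quants (qh), 16 signed 8-bit group
+    # scales (sc), and a trailing fp16 super-scale.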
+ ql = data_u8[:, :128].astype(np.int16) + qh = data_u8[:, 128:192].astype(np.int16) + sc = data_i8[:, 192:208, np.newaxis].astype(np.float32) + + # Unpack bits, subtraction requires signed data type + q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32 + q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32 + q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32 + q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32 + q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32 + q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32 + q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32 + q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32 + + # Dequantize + weight = scales * np.concatenate([ + sc[:, 0] * q1[:, :16], + sc[:, 1] * q1[:, 16:], + sc[:, 2] * q2[:, :16], + sc[:, 3] * q2[:, 16:], + sc[:, 4] * q3[:, :16], + sc[:, 5] * q3[:, 16:], + sc[:, 6] * q4[:, :16], + sc[:, 7] * q4[:, 16:], + sc[:, 8] * q5[:, :16], + sc[:, 9] * q5[:, 16:], + sc[:, 10] * q6[:, :16], + sc[:, 11] * q6[:, 16:], + sc[:, 12] * q7[:, :16], + sc[:, 13] * q7[:, 16:], + sc[:, 14] * q8[:, :16], + sc[:, 15] * q8[:, 16:], + ], axis=1) + + if device is None: + return weight + return torch.from_numpy(weight).to(device=device) + +QK_K = 256 +kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8) + +def dequantize_iq4_xs(data): + block_size = GGML_BLOCK_SIZES["IQ4_XS"] + num_blocks = len(data) // block_size + + d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1) + scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:] + scales_l = data_u8[:, :4].reshape(num_blocks, 4) + qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8) + + ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8) + for ib in range(QK_K // 32): + ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4) + + dl = (d * (ls - 32)).reshape(num_blocks, -1, 1) + + qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf + qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4 + + y = np.zeros((num_blocks, QK_K), dtype=np.float32) + for ib in range(QK_K // 32): + y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]] + y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]] + + return y.flatten() + +GGML_DEQUANTIZE = { + int(GGMLQuantizationType.F32): dequantize_f32, + int(GGMLQuantizationType.F16): dequantize_f16, + int(GGMLQuantizationType.Q4_0): dequantize_q4_0, + int(GGMLQuantizationType.Q5_0): dequantize_q5_0, + int(GGMLQuantizationType.Q8_0): dequantize_q8_0, + int(GGMLQuantizationType.Q2_K): dequantize_q2_k, + int(GGMLQuantizationType.Q3_K): dequantize_q3_k, + int(GGMLQuantizationType.Q4_K): dequantize_q4_k, + int(GGMLQuantizationType.Q5_K): dequantize_q5_k, + int(GGMLQuantizationType.Q6_K): dequantize_q6_k, + int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs, +} + + +def dequant_gguf(data, type, shape): + values = GGML_DEQUANTIZE[type](data) + values = torch.from_numpy(values).view(shape) + return values \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py new file mode 100644 index 0000000..5b0e629 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +import re +from copy import deepcopy +from typing import Dict, Optional, Union + +import torch + +from vllm.config import QuantizationConfig +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, UnquantizedEmbeddingMethod) + + +# Match dynamic rules with module name (prefix) and override quantize +# config if module (prefix) matches a rule +def override_config(config: QuantizationConfig, prefix: str): + weight_bits = get_dynamic_override(config, prefix, "bits", + config.weight_bits) + if isinstance(weight_bits, int): + config.weight_bits = weight_bits + group_size = get_dynamic_override(config, prefix, "group_size", + config.group_size) + if isinstance(group_size, int): + config.group_size = group_size + desc_act = get_dynamic_override(config, prefix, "desc_act", + config.desc_act) + if isinstance(desc_act, bool): + config.desc_act = desc_act + + config.pack_factor = 32 // config.weight_bits # packed into int32 + if config.get_name() == "gptq_marlin": + is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) + if isinstance(is_sym, bool): + config.is_sym = is_sym + + if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}") + + config.quant_type = config.TYPE_MAP[(config.weight_bits, + config.is_sym)] + elif config.get_name() == "gptq": + if config.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {config.weight_bits} bits.") + + +def get_dynamic_override( + config: QuantizationConfig, + layer_name: str, + key: Optional[str] = None, + default_value: Union[int, bool, + None] = None) -> Union[Dict, int, bool, None]: + for pattern, pattern_dict in config.dynamic.items(): + # Negative match: matched modules are excluded from quantized init + if pattern.startswith("-:"): + if re.match(pattern.removeprefix("-:"), layer_name): + return False + # Positive match: matched modules have quant properties overrides + # base quant config + elif re.match(pattern.removeprefix("+:"), layer_name): + if key is None: + return pattern_dict + else: + return pattern_dict.get(key, default_value) + return default_value + + +def get_linear_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + linear_method_cls: type, +): + cloned_config = deepcopy(config) + parallel_lm_head_quantized = isinstance( + layer, ParallelLMHead) and cloned_config.lm_head_quantized + if isinstance(layer, LinearBase) or parallel_lm_head_quantized: + # False = skip module, None = no override, else = Positive match + if get_dynamic_override( # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix) == False: # noqa: E712 + if parallel_lm_head_quantized: + return UnquantizedEmbeddingMethod() + return UnquantizedLinearMethod() + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return linear_method_cls(cloned_config) + return None diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000..5acae7c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py 
@@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if type(old) is type(new) and old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for `torch.compile` compatibility + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, + torch.nn.Parameter(new, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 0000000..cb7d49e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py new file mode 100644 index 0000000..d1fb52a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple + +import numpy +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 
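+# Shape constraints derived from these constants: the N (output) partition must
+# be divisible by GPTQ_MARLIN_MIN_THREAD_N and the K (input) partition by
+# GPTQ_MARLIN_MIN_THREAD_K; the reduction workspace is sized as
+# (output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL.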
+ +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. +# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_marlin_supported_quant_types( + has_zp, device_capability) + + if quant_type not in supported_types: + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. 
" + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + +def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ + -> bool: + output_size_per_partition = getattr(layer, "output_size_per_partition", + None) or layer.output_size + input_size_per_partition = getattr(layer, "input_size_per_partition", + None) or layer.input_size + + return check_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=layer.input_size, + group_size=group_size)[0] + + +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = 
marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda": + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + return False + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + return max(m, 64) * n < 64 * 2048 and k >= 2048 + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + 
device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 0000000..6120a8e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + +logger = init_logger(__name__) + + +def is_fp8_marlin_supported(): + return current_platform.has_device_capability(80) + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, + strategy: str = "tensor") -> None: + logger.warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32( + layer.weight), + perm=torch.empty(0, + dtype=torch.int, + device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales(s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = (byte_tensor[:, 0].to(torch.int32) | + (byte_tensor[:, 1].to(torch.int32) << 8) | + (byte_tensor[:, 2].to(torch.int32) << 16) | + (byte_tensor[:, 3].to(torch.int32) << 24)) + + return packed.view(fp8_tensor.shape[0] // 4, + *fp8_tensor.shape[1:]).contiguous() diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000..fb557a3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, + marlin_zero_points) +from .quant_utils import (get_pack_factor, gptq_quantize_weights, + quantize_weights, sort_weights) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], 
q_w.shape[1] // pack_factor), + dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, + group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, + quant_type, + group_size, + zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, + quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000..3654268 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,464 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import 
gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
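+# As a rough illustration of the resulting shapes (assuming a torch.half
+# input): a dense (m, k) matrix becomes a values tensor of shape (m, k // 2)
+# plus an int16 metadata tensor of shape (m, k // 16), since each int16 meta
+# element packs four 4-bit codes, one per group of 4 dense elements.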
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
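+ # Worked example (illustrative): the quadruple [True, True, False, False]
+ # has m0 = m1 = 1 and m3 = 0, so expr0 = 1 and expr1 = expr2 = 0. That
+ # yields bit0 = bit1 = bit3 = 0 and bit2 = 1, hence idxs0 = 0b00 and
+ # idxs1 = 0b01 (the kept values sit at positions 0 and 1), and the meta
+ # nibble below is idxs0 | (idxs1 << 2) = 0b0100, matching the table above.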
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
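+
+ For example, with N = 2 and M = 4, the weight group [0.3, -1.2, 0.05, 0.7]
+ keeps its two largest-magnitude values, giving the mask [0, 1, 0, 1].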
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + quant_type.size_bits, weight_perm) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000..176b294 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py new file mode 100644 index 0000000..45e8253 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -0,0 +1,568 @@ +# SPDX-License-Identifier: Apache-2.0 +"""This file is used for /tests and /benchmarks""" +from types import MappingProxyType +from typing import List, Mapping, Optional, Tuple + +import numpy +import torch + +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +# Normalize the group_shape to the full extent for any dims that are -1 +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int, + int]): + # -1 means full extent + return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], + group_shape[1] if group_shape[1] > 0 else x.shape[-1]) + + +# Useful when treating N-dimensional group scaling as extended numpy-style +# broadcasting in numpy simply stretches dimensions with an extent 
of 1 to match +# the target shape by repeating the data along that dimension (broadcasting) +# , we extend these semantics to say if the extent of a dimension in the +# source shape is not 1 and does not match the target shape we repeat each +# element along that dimension src_shape[dim] // target_shape[dim] times +# example if we have: +# a = [[1, 2], and target_shape = (2, 4) +# [3, 4]] +# then we would expand a to: +# a = [[1, 1, 2, 2], +# [3, 3, 4, 4]] +# NOTE this function this function does not explicitly broadcast dimensions +# with an extent of 1, since this can be done implicitly by pytorch +def group_broadcast(t, shape): + for i, s in enumerate(shape): + if t.shape[i] != s and t.shape[i] != 1: + assert s % t.shape[i] == 0 + t = t.unsqueeze(i + 1)\ + .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ + .flatten(i, i + 1) + return t + + +# Quantize assuming once scale per group of elements with shape group_shape, +# example group shapes: +# * (-1, -1) for per-tensor quantization +# * (1, -1) for per-row quantization +# * (-1, 1) for per-column quantization +# * (128, 128) for 128x128 deepseek style block quantization +# * (1, 128) for deepseek style activation quantization +# (i.e. per-token-per-group) +def scaled_quantize( + x: torch.Tensor, + group_shape: Tuple[int, int], + quant_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + group_shape = _normalize_quant_group_shape(x, group_shape) + + finfo = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype) + + # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N) + assert x.ndim == 2 + assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0 + blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1] + x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1]) + + # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) + x_blkd_permd = x_blkd.permute(0, 2, 1, 3) + # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) + x_blkd_permd = x_blkd_permd.flatten(start_dim=2) + + # Compute scales + min_val, max_val = x_blkd_permd.aminmax(dim=-1) + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + + # Apply scale and convert form: + # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) + x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\ + .clamp(min=finfo.min, max=finfo.max)\ + .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\ + .permute(0, 2, 1, 3)\ + .reshape(x.shape) + + return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal() + + +# inverses `scaled_quantize` +def scaled_dequantize( + x_q: torch.Tensor, + x_s: torch.Tensor, + group_shape: Optional[Tuple[int, int]] = None, + out_dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + if group_shape is not None: + group_shape = _normalize_quant_group_shape(x_q, group_shape) + + if x_s.ndim == 0: # scalar + x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor + if x_s.ndim == 1: + if group_shape is None: + raise AssertionError( + "if x_s is 1D tensor, group_shape must be provided otherwise " + "its ambiguous which dimension to broadcast x_s to") + # unsqueeze the scales for the dimension where we want to broadcast + # across the full extent + if group_shape[0] == x_q.shape[-2]: + x_s = x_s.unsqueeze(-2) + elif group_shape[1] == x_q.shape[-1]: + x_s = x_s.unsqueeze(-1) + else: + raise AssertionError( + "if x_s is a vector we should be broadcasting it to the full " + "extent of one 
of the dimensions") + + if group_shape is not None: + assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] + assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0] + x_s = group_broadcast(x_s.to(torch.float32), x_q.shape) + return (x_q.to(torch.float32) * x_s).to(out_dtype) + + +def pack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped( + prefix: str, + ignored_layers: List[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in fused_mapping[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. 
All shards of fused layers " + "to have the same precision.") + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False): + assert quant_type.is_integer(), \ + "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, \ + "to have group zero points, group_size must be provided "\ + "(-1 group_size is channelwise)" + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
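+ # Concretely: w_ref = w_q * s - zp * s (zero-point applied after the scale)
+ # rather than the usual w_ref = (w_q - zp) * s; the two are algebraically
+ # identical but round differently in low-precision arithmetic.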
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
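+# Roughly: the per-group path quantizes to unsigned `num_bits` values with
+# fp16 group scales that are then fused over an int8 per-channel scale,
+# while the per-channel path (group_size == size_k) performs symmetric
+# signed quantization with a single float scale per output channel.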
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ + f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / + s_channel).to(dtype=torch.half) + else: + max_q_val = 2**(num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= (2**(8 - num_bits)) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + 
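+ # Pack each run of pack_factor adjacent columns into one int32 word,
+ # num_bits bits per value, with lower columns in the less significant bits.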
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, size_n // pack_factor + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py new file mode 100644 index 0000000..b8e6384 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.platforms import current_platform + +# Input scaling factors are no longer optional in _scaled_mm starting +# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale +TORCH_DEVICE_IDENTITY = None + +# The condition to determine if it is on a platform that supports +# torch._scaled_mm rowwise feature. +# The condition is determined once as the operations +# are time consuming. 
+USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and current_platform.has_device_capability(94)) + + +def sparse_cutlass_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_sparse_scaled_mm_supported(capability) + + +def cutlass_fp8_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_scaled_mm_supports_fp8(capability) + + +def cutlass_block_fp8_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_scaled_mm_supports_block_fp8(capability) + + +def cutlass_group_gemm_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_group_gemm_supported(capability) + + +CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported() +CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() + + +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, + torch.Tensor]) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def convert_to_channelwise( + weight_scale: torch.Tensor, + logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + # Create channelwise buffer + weight_scale_channel = torch.empty((sum(logical_widths), 1), + dtype=torch.float32, + device=weight_scale.device) + + # Expand each scale to match the size of each logical matrix. + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_scale_channel[start:end, :] = weight_scale[idx] + start = end + + return weight_scale_channel + + +def requantize_with_max_scale( + weight: torch.Tensor, weight_scale: torch.Tensor, + logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requanitzation. + max_w_scale = weight_scale.max() + + # QKV / MLP is fused in the on disk checkpoint if any of the + # weight scales are still set to the default since we initialize + # N weight scales for N shards but we only load 1 weight scale + # from disk in this case. Skip requantization in this case (since) + # we already are quantized with the single scale. + # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 + unfused_module_in_checkpoint = (weight_scale[-1] + > torch.finfo(torch.float8_e4m3fn).min) + + # If unfused checkpoint, need requanize with the single scale. 
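+ # i.e. dequantize each logical shard with its own loaded scale, then
+ # re-quantize everything against the shared max_w_scale.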
+ if unfused_module_in_checkpoint: + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(weight[start:end, :], + weight_scale[idx]) + weight[start:end, :], _ = ops.scaled_fp8_quant( + weight_dq, max_w_scale) + start = end + + return max_w_scale, weight + + +def maybe_create_device_identity(): + # Allocate dummy ones tensor for torch._scaled_mm + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY is None: + TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) + + +# TODO(luka): follow similar pattern for marlin and block-fp8-linear +# https://github.com/vllm-project/vllm/issues/14397 +class Fp8LinearOp: + """ + This class executes a FP8 linear layer using cutlass if supported and + torch.scaled_mm otherwise. + It needs to be a class instead of a method so that config can be read + in the __init__ method, as reading config is not allowed inside forward. + """ + + def __init__(self, + cutlass_fp8_supported: bool = cutlass_fp8_supported(), + use_per_token_if_dynamic: bool = False, + pad_output: Optional[bool] = None): + self.cutlass_fp8_supported = cutlass_fp8_supported + self.use_per_token_if_dynamic = use_per_token_if_dynamic + + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. + if pad_output is None: + config = get_current_vllm_config().compilation_config + pad_output = config.level < CompilationLevel.PIECEWISE + self.output_padding = 17 if pad_output else None + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + # TODO(luka) remove this parameter in favor of __init__ + use_per_token_if_dynamic: Optional[bool] = None + ) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[1]] + + # TODO(luka) this is here because currently MLA only decides this + # during the forward method instead of in __init__. 
+ if use_per_token_if_dynamic is None: + use_per_token_if_dynamic = self.use_per_token_if_dynamic + + if out_dtype is None: + out_dtype = input.dtype + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if self.cutlass_fp8_supported: + assert input.dtype != current_platform.fp8_dtype( + ), "FP8 input to cutlass is not currently implemented" + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + scale_ub=input_scale_ub, + use_per_token_if_dynamic=use_per_token_if_dynamic) + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + return output.view(*output_shape) + + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + else: + if input.dtype != current_platform.fp8_dtype(): + # Maybe apply padding to output, see comment in __init__ + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=self.output_padding, + use_per_token_if_dynamic=use_per_token_if_dynamic) + else: + qinput, x_scale = input_2d, input_scale + + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) + + if per_tensor_weights and per_tensor_activations: + # Fused GEMM_DQ + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, + input_2d.shape[0]).view(*output_shape) + + elif (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations + and USE_ROWWISE_TORCH_SCALED_MM): + # For now validated on ROCm platform + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt + # and ROCm 6.3, which only exists in torch 2.7 and above. + # For CUDA platform please validate if the + # torch._scaled_mm support rowwise scaled GEMM + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # GEMM + # This computes C = (X * W). 
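+ # (TORCH_DEVICE_IDENTITY is a dummy ones(1) scale created by
+ # maybe_create_device_identity, so no scaling happens inside the matmul.)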
+ # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. + # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale diff --git a/vllm/model_executor/layers/quantization/w8a16.py b/vllm/model_executor/layers/quantization/w8a16.py new file mode 100644 index 0000000..6c42ce7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/w8a16.py @@ -0,0 +1,114 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.model_executor.utils import set_weight_attrs + + +class W8a16Config(QuantizationConfig): + """Config class for W8a16. + + """ + + def __init__( + self, + ) -> None: + pass + + def __repr__(self) -> str: + return ("W8a16Config") + + def get_name(self) -> str: + return "w8a16" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + def get_min_capability(self) -> int: + return 75 + + @staticmethod + def get_config_filenames(): + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W8a16Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["W8a16LinearMethod"]: + if isinstance(layer, LinearBase): + return W8a16LinearMethod(self) + return None + + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W8a16LinearMethod(LinearMethodBase): + """Linear method for w8a16. 
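+
+ Weights are stored as int8 with one scale per output channel (kept in
+ params_dtype); activations remain in half/bfloat16 and the matmul is
+ dispatched to ops.linear_w8a16.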
+ + """ + + def __init__(self, quant_config: W8a16Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs( + weight, { + "input_dim": 1, + "output_dim": 0, + }) + + scales = Parameter( + torch.empty( + 1, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": None, + "output_dim": 1, + }) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.weight + scales = layer.scales + out_shape = (x.shape[:-1] + (qweight.shape[-2],)) + reshaped_x = x.reshape(-1, x.shape[-1]) + out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN") + if bias is not None: + out = out + bias + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py new file mode 100644 index 0000000..617f2ad --- /dev/null +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import cached_property +from importlib.util import find_spec +from typing import Dict, Optional, Tuple + +import torch +import torch.jit + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeStochasticBaseSampler) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +# if find_spec("flashinfer"): +# """ +# Consider utilizing the FlashInfer rejection sampling kernel initially, +# as it employs a dedicated kernel rather than relying on +# Torch tensor operations. This design choice helps to fuse operations, +# reduce memory I/O, and consequently enhances performance. +# """ +# from flashinfer.sampling import chain_speculative_sampling +# else: +# chain_speculative_sampling = None +chain_speculative_sampling = None + +class RejectionSampler(SpecDecodeStochasticBaseSampler): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, + strict_mode: bool = False, + use_flashinfer: Optional[bool] = None): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + use_flashinfer: We will use this parameter to determine whether + to use the FlashInfer rejection sampling kernel or not. If it's + None, we will use the default value from the environment variable. + This parameter is only used for testing purposes. 
+ """ + super().__init__(strict_mode=strict_mode) + if use_flashinfer is None: + self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( + chain_speculative_sampling is not None) + else: + self.use_flashinfer = use_flashinfer + + if self.use_flashinfer: + logger.info("Use flashinfer for rejection sampling.") + else: + logger.info("Use pytorch for rejection sampling.") + + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + seeded_seqs: Optional[Dict[int, torch.Generator]] = None, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_with_bonus_probs: The probability distribution + over token ids given context according to the target model. + shape = [batch_size, num_speculative_tokens + 1, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + seeded_seqs: Dict of batch row index to torch generator, for + sequences using seeded generation. + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + + batch_size, k, _ = draft_probs.shape + + # batch_size = 0 when all requests in the batch are + # non_spec requests. In this case, output_token_ids is + # just an empty tensor. + if batch_size == 0: + return torch.empty(0, k + 1, device=draft_probs.device, dtype=int) + + # If use Flashinfer chain_speculative_sampling kernel + # for rejection sampling + if self.use_flashinfer and chain_speculative_sampling is not None: + batch_size, k, _ = draft_probs.shape + uniform_samples = self._create_uniform_samples( + seeded_seqs, batch_size, k, draft_probs.device) + output_token_ids, accepted_token_num, emitted_token_num \ + = chain_speculative_sampling( + draft_probs, draft_token_ids, uniform_samples, + target_with_bonus_probs) + + # num_emitted_tokens returned by flashinfer + # does not include the bonus token + # Flashinfer stops at the first token that violates + # the condition p >= q and does not include recovery/bonus token. + # Therefore, we need to add batch_size here. 
+ self.num_accepted_tokens += accepted_token_num.sum() + self.num_emitted_tokens += emitted_token_num.sum() + batch_size + self.num_draft_tokens += batch_size * k + else: + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_with_bonus_probs[:, :-1], + draft_probs, + draft_token_ids, + seeded_seqs, + )) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + + return output_token_ids + + def _batch_modified_rejection_sampling( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + seeded_seqs: Optional[Dict[int, torch.Generator]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Perform modified rejection sampling on each sequence. + + Returns: + A tuple of two tensors: + 0: A bool tensor of which tokens in each sequence is accepted. + shape = [batch_size, k] + 1: Token ids sampled from a recovered distribution, to be used + when a token is rejected. + shape = [batch_size, k] + """ + + batch_size, k, vocab_size = draft_probs.shape + + # shape [batch_size, k] + accepted = self._get_accepted(target_probs, draft_probs, + draft_token_ids, seeded_seqs) + + recovered_probs = self._get_recovered_probs( + target_probs, draft_probs).reshape(batch_size * k, vocab_size) + + # NOTE: the recovered_probs are overwritten by this method. + recovered_token_ids = _multinomial( + recovered_probs, + num_samples=1, + k=k, + seeded_seqs=seeded_seqs or {}, + ).reshape(batch_size, k) + + return accepted, recovered_token_ids + + def _create_uniform_samples(self, + seeded_seqs: Optional[Dict[int, + torch.Generator]], + batch_size: int, k: int, + device: torch.device) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + for specific sequences. + + This method creates a tensor of shape `(batch_size, k + 1)` filled + with uniform random values in the range [0, 1). If `seeded_seqs` + is provided, the sequences corresponding to specific indices + will be generated using the provided `torch.Generator` for + reproducibility. The other sequences will be generated without + a seed. + + Args: + seeded_seqs : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. If `None`, all samples are + generated without a seed. + batch_size : int + The number of sequences to generate. + k : int + The number of random samples per sequence. + device : torch.device + The device on which to allocate the tensor. + + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(batch_size, k + 1)` containing uniform + random values in the range [0, 1). 
+ """ + if not seeded_seqs: + return torch.rand(batch_size, k + 1, device=device) + + uniform_rand = torch.empty(batch_size, k + 1, device=device) + + non_seeded_indices = [] + for idx in range(batch_size): + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.append(idx) + else: + uniform_rand[idx, :] = torch.rand(1, + k + 1, + dtype=self.probs_dtype, + device=device, + generator=generator) + if non_seeded_indices: + uniform_rand[non_seeded_indices, :] = torch.rand( + len(non_seeded_indices), + k + 1, + dtype=self.probs_dtype, + device=device) + return uniform_rand + + def _get_accepted( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + seeded_seqs: Optional[Dict[int, torch.Generator]], + ) -> torch.Tensor: + r"""Create bool matrix over the proposed draft tokens. If + True, then a token can be accepted, else it should be + rejected. + + Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of + :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according + to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the + same conditional probability according to the draft model, the token + is accepted with probability: + + .. math:: + \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} + {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) + + This implementation does not apply causality. When using the output, + if a token is rejected, subsequent tokens should not be used. + + Returns a bool tensor of shape [batch_size, k] specifying which tokens + are accepted. + """ + batch_size, k, _ = draft_probs.shape + batch_indices = torch.arange(batch_size, + device=target_probs.device)[:, None] + probs_indicies = torch.arange(k, device=target_probs.device) + + # shape [batch_size, k] + selected_draft_probs = draft_probs[batch_indices, probs_indicies, + draft_token_ids] + + # shape [batch_size, k] + selected_target_probs = target_probs[batch_indices, probs_indicies, + draft_token_ids] + + uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size, + k - 1, target_probs.device) + + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1, ), 1, device=target_probs.device)) + accepted = uniform_rand < capped_ratio + + return accepted + + def _get_recovered_probs( + self, + target_probs: torch.Tensor, # [k, vocab_size] + draft_probs: torch.Tensor, # [k, vocab_size] + ) -> torch.Tensor: + r"""Create a probability distribution for each proposed token which can + be sampled if the proposed token is rejected. + + When this routine is applied sequentially, the true distribution of the + target model is recovered (within hardware numerics). + + The probability distribution used in this rejection case is constructed + as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of + :math:`x` given context :math:`x_1, \dots, x_n` according to the target + model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability + according to the draft model: + + .. math:: + x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ + + where :math:`(f(x))_+` is defined as: + + .. math:: + (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} + + See https://github.com/vllm-project/vllm/pull/2336 for a visualization + of the draft, target, and recovered probability distributions. + + Returns a tensor of shape [batch_size, k, vocab_size]. 
+ + Note: This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. This introduces some drift to the distribution. + """ + _, k, _ = draft_probs.shape + + # shape [batch_size, k, vocab_size] + difference = target_probs - draft_probs + + # TODO(cade): Can we use logprobs instead of probs, and avoid the + # division-by-zero errors without introducing distribution drift? + + # shape [batch_size, k, vocab_size] + f = torch.clamp(difference, min=self._smallest_positive_value) + + # shape [batch_size, k, vocab_size] + recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) + + return recovered_probs + + @cached_property + def _smallest_positive_value(self) -> float: + """Return the smallest positive value representable by the probs dtype. + This value is used when constructing a distribution from which to sample + recovered tokens in the first rejection case. + + See _get_recovered_probs for more details + + Note that this isn't actually the smallest positive value representable + by float32, but the smallest positive normal value. + See https://en.wikipedia.org/wiki/Subnormal_number for more information. + """ + return torch.finfo(self.probs_dtype).tiny + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead that skips the sync. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def _multinomial( + probs: torch.Tensor, + num_samples: int, + k: int, + seeded_seqs: Dict[int, torch.Generator], +) -> torch.Tensor: + + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs) + if not seeded_seqs: + q.exponential_(1.0) + else: + start = 0 + for idx in range(len(q) // k): + end = start + k + generator = seeded_seqs.get(idx) + # Note: generator might be None for non seeded + q[start:end].exponential_(1.0, generator=generator) + start = end + + return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py new file mode 100644 index 0000000..4c98600 --- /dev/null +++ b/vllm/model_executor/layers/resampler.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +# +# Copyright 2023 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared resampler perceiver network used in multimodal models and +related helpers for sincos positional embeddings. + +Example models: Qwen (Qwen-VL), MiniCPM-V 2.0 +""" +import math +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, + int]) -> torch.Tensor: + # abs_pos: L, C + # tgt_size: (H, W) + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + dtype = abs_pos.dtype + if isinstance(tgt_size, int): + tgt_size = (tgt_size, tgt_size) + if (src_size == tgt_size[0] and src_size == tgt_size[1]): + return abs_pos + return (F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + + +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, 
embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and \ + grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb. + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim)) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear(kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj") + else: + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = nn.Parameter( + (embed_dim**-0.5) * + torch.empty(embed_dim, embed_dim)) if do_post_projection else None + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + """Resampler-perceiver network to be used for a variety of model types, + e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the + do_post_projection arg, which indicates whether or not there should be + a post layer normalization and projector after the attention. This is + present in minicpmv2.0, but not qwen-vl. 
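+    The (grid_size**2) learnable queries attend over the projected input
+    features, with 2D sincos positional embeddings added to both the
+    queries and the keys.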
+ """ + + def __init__(self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection, + quant_config=quant_config, + prefix=prefix) + + self.adaptive = adaptive + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 0)) + + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).requires_grad_(False)) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if tgt_sizes is None: + tgt_sizes = int(math.sqrt(x.size(1))) + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) + else: + pos_embed = get_abs_pos(self.pos_embed, + tgt_sizes).to(device=x.device, + dtype=x.dtype) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + if self.do_post_projection: + x = self.ln_post(x) + x = x @ self.proj + return x diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py new file mode 100644 index 0000000..fa2e0d0 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -0,0 +1,1228 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Rotary Positional Embeddings.""" +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform + +import ixformer.inference.functions as ixops + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +@CustomOp.register("rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
+ inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible + if self.cos_sin_cache.device != query.device or \ + self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. 
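+        # Offsets, when provided, shift each token's index into the
+        # (possibly concatenated) cos/sin cache, e.g. for per-adapter linear
+        # scaling, so the batched kernel variant has to be used.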
+ if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_hpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, apply_rotary_pos_emb) + if offsets is not None: + offsets = offsets.view(positions.shape[0], -1) + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions).view( + num_tokens, 1, -1) + cos, sin = cos_sin.chunk(2, dim=-1) + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE + # and expansion of cos/sin tensors via repeat_interleave + rope_mode: RotaryPosEmbeddingMode + if self.is_neox_style: + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + cos = torch.repeat_interleave(cos, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, + rope_mode) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_neuron( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + def _apply_rotary_emb_neuron( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + ) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + # x1 = x[..., ::2] + + # x2 = x[..., 1::2] + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) + x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + if offsets is not None: + positions = positions + offsets + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = 
query.view(num_tokens, -1, self.head_size) + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + + if self.rotary_dim == self.head_size: + query = _apply_rotary_emb(query, cos, sin, self.is_neox_style) + query = query.reshape(query_shape) + key = _apply_rotary_emb(key, cos, sin, self.is_neox_style) + key = key.reshape(key_shape) + else: + head_size = query.shape[-1] + query_reshaped = query.view(-1, head_size) + query_pass = query_reshaped[:, self.rotary_dim:].view( + *query.shape[:-1], head_size - self.rotary_dim) + query_rot = query_reshaped[:, :self.rotary_dim].view( + *query.shape[:-1], self.rotary_dim) + query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), + dim=-1).reshape(query_shape) + + key_reshaped = key.view(-1, head_size) + key_pass = key_reshaped[:, self.rotary_dim:].view( + *key.shape[:-1], head_size - self.rotary_dim) + key_rot = key_reshaped[:, :self.rotary_dim].view( + *key.shape[:-1], self.rotary_dim) + key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factors: Union[List[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: List[float] = scaling_factors # noqa + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. 
+ # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim(num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> float: + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, + dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + _yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: List[float], + long_factor: List[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
+ ) + + self.rotary_dim = rotary_dim + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale) + short_cache = short_cache.to(dtype) + + long_cache = self._compute_cos_sin_cache(max_position_embeddings, + long_factor, long_mscale) + long_cache = long_cache.to(dtype) + + long_short_cache = torch.cat([short_cache, long_cache], dim=0) + self.register_buffer("long_short_cos_sin_cache", + long_short_cache, + persistent=False) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + k = self.original_max_position_embeddings + long_prompt_offset = torch.any(positions > k) + + ixops.vllm_rotary_embedding_phi( + positions, + query, + key, + self.head_size, + self.long_short_cos_sin_cache, + long_prompt_offset, + k, + offsets + ) + + return query, key + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. 
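+        # yarn_get_mscale(s, m) = 0.1 * m * ln(s) + 1 for s > 1, so the
+        # ratio below is 1.0 when mscale == mscale_all_dim and otherwise
+        # adjusts the attention magnitude correction, scaled by attn_factor.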
+ self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, + self.rotary_dim, + 2, + dtype=torch.float, + device=current_platform.device_type) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device=current_platform.device_type, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs + + +class Llama4VisionRotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ): + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + inv_freqs = inv_freqs[:(self.rotary_dim // 2)] + return inv_freqs + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + + # self.max_position_embeddings here is number of image patches + # i.e. 
(image_size // patch_size) ** 2 + num_patches = self.max_position_embeddings + img_idx = torch.arange(num_patches, + dtype=torch.int32) \ + .reshape(num_patches, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN + num_patches_single_dim = int(math.sqrt(num_patches)) + frequencies_x = img_idx % num_patches_single_dim + frequencies_y = img_idx // num_patches_single_dim + freqs_x = ((frequencies_x + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs_y = ((frequencies_y + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], + dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + cache = torch.view_as_complex( + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + return cache + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) + query_ = torch.view_as_complex(query.float().reshape( + *query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape( + *key.shape[:-1], -1, 2)) + broadcast_shape = [ + d if i == 1 or i == (query_.ndim - 1) else 1 + for i, d in enumerate(query_.shape) + ] + freqs_ci = self.cos_sin_cache.view(*broadcast_shape) + query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) + key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) + return query_out.type_as(query), key_out.type_as(key) + + +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + ) -> None: + # In Qwen2.5-VL, the maximum index value is related to the duration of + # the input video. We enlarge max_position_embeddings to 4 times to get + # a larger the cos and sin cache. + self.cache_max_position_num = max_position_embeddings * 4 + super().__init__(head_size, rotary_dim, self.cache_max_position_num, + base, is_neox_style, dtype) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward(). 
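+        For multimodal (2-D) positions, the rotary dimensions are split
+        according to mrope_section and each section takes its cos/sin
+        values from the corresponding T/H/W position row.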
+ + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat([ + m[i] + for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] + for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @staticmethod + def get_input_positions( + input_tokens: List[int], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], + video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], + second_per_grid_ts: Optional[List[float]], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> Tuple[List[List[int]], int]: + """Get mrope input positions and delta value.""" + + image_grid_thw = [] if image_grid_thw is None else image_grid_thw + video_grid_thw = [] if video_grid_thw is None else video_grid_thw + second_per_grid_ts = [] if second_per_grid_ts is None else \ + second_per_grid_ts + + llm_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_input_positions_tensor( + input_tokens: List[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + second_per_grid_ts: List[float], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, + "tokens_per_second", 1.0) + + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = 
image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * + tokens_per_second).long().flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> List[List[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + @staticmethod + def get_next_input_positions_tensor( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> torch.Tensor: + return torch.arange( + mrope_position_delta + context_len, + mrope_position_delta + seq_len, + ).expand(3, -1) + + +_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = (head_size, rotary_dim, max_position, 
base, is_neox_style, + rope_scaling_args, dtype) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, dtype) + else: + scaling_type = rope_scaling["rope_type"] + + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) + elif scaling_type == "mllama4": + rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) + elif scaling_type == "linear": + scaling_factor = rope_scaling["factor"] + rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype) + elif scaling_type == "dynamic": + scaling_factor = rope_scaling["factor"] + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_factor, dtype) + elif scaling_type == "yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, + original_max_position, + base, is_neox_style, + scaling_factor, dtype, + **extra_kwargs) + elif scaling_type == "deepseek_yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow", "mscale", "mscale_all_dim") + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) + elif scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, rotary_dim, max_position, original_max_position, + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py new file mode 100644 index 0000000..1ee1332 --- /dev/null +++ b/vllm/model_executor/layers/sampler.py @@ -0,0 +1,1221 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A layer that samples the next tokens 
from the model's outputs."""
+import itertools
+import warnings
+from dataclasses import dataclass
+from importlib.util import find_spec
+from math import inf
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import msgspec
+import torch
+import torch.nn as nn
+
+import vllm.envs as envs
+from vllm.model_executor.layers.utils import apply_penalties
+from vllm.model_executor.sampling_metadata import (SamplingMetadata,
+                                                   SamplingTensors,
+                                                   SequenceGroupToSample)
+from vllm.sampling_params import SamplingType
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
+                           CompletionSequenceGroupOutput, Logprob,
+                           PromptLogprobs, SampleLogprobs, SequenceOutput)
+from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
+
+if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
+    import flashinfer.sampling
+    # yapf: disable
+    from flashinfer.sampling import (
+        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
+
+    # yapf: enable
+else:
+    flashinfer_top_k_top_p_sampling = None
+
+
+def get_sampler() -> torch.nn.Module:
+    if envs.VLLM_USE_V1:
+        # Lazy import: the v1 package isn't distributed
+        from vllm.v1.sample.sampler import Sampler as V1Sampler
+        return V1Sampler()
+    return Sampler()
+
+
+# (num_token_ids, num_parent_ids) per sequence group.
+SampleResultType = List[Tuple[List[int], List[int]]]
+
+# Types of temporary data structures used for
+# computing sample_result
+SampleMetadataType = Dict[SamplingType, Tuple[List[int],
+                                              List[SequenceGroupToSample]]]
+MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
+SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
+
+
+# Encapsulates temporary data structures for computing
+# sample_result.
+#
+# * For multi-step scheduling: must be returned
+# by `Sampler.forward()` and used later to compute the pythonized
+# sample_result
+#
+# * For single-step scheduling: consumed immediately
+# inside `Sampler.forward()` to compute pythonized sample_result.
+@dataclass
+class SampleResultArgsType:
+    sample_metadata: SampleMetadataType
+    multinomial_samples: MultinomialSamplesType
+    sample_results_dict: SampleResultsDictType
+    sampling_metadata: SamplingMetadata
+    greedy_samples: Optional[torch.Tensor]
+
+
+# Union of non-deferred (single-step scheduling)
+# vs deferred (multi-step scheduling)
+# sample result types
+MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
+
+# Abbreviation of the _sample() return type
+SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
+
+
+class SamplerOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+    """For each sequence group, we generate a list of SequenceOutput objects,
+    each of which contains one possible candidate for the next token.
+
+    This data structure implements methods, so it can be used like a list, but
+    also has optional fields for device tensors.
+    """
+
+    outputs: List[CompletionSequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional[torch.Tensor] = None
+
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None
+
+    # Holds either (1) the pythonized sampler result (single-step scheduling)
+    # or (2) what will be arguments for later deferred pythonization of the
+    # sampler result (multi-step scheduling)
+    deferred_sample_results_args: Optional[SampleResultArgsType] = None
+
+    # On-device tensor containing the sampled token ids.
+ sampled_token_ids: Optional[torch.Tensor] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None + + # Spec decode metrics populated by workers. + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None + + # Optional last hidden states from the model. + hidden_states: Optional[torch.Tensor] = None + + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + + # Time taken in the model execute function. This will include model forward, + # block/sync across workers, cpu-gpu sync time and sampling time. + model_execute_time: Optional[float] = None + + def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]: + return iter(self.outputs) + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + + +class Sampler(nn.Module): + """Samples the next tokens from the model's outputs. + + This layer does the following: + 1. Discard the hidden states that are not used for sampling (i.e., all + tokens except the final one in each prompt). + 2. Compute the logits for the next tokens. + 3. Apply presence, frequency and repetition penalties. + 4. Apply temperature scaling. + 5. Apply top-p and top-k truncation. + 6. Sample the next tokens. + Here, each sequence group within the batch can have different sampling + parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + + The structure of the logits tensor is coupled with the seq_groups in + sampling_metadata. Typically, each sequence in each seq_group has one row in + logits for the next token to be sampled; however, for a seq_group with a + prompt request with the prompt_logprobs sampling parameter, there are rows + in logits for each token in the input prompt. + """ + + def __init__(self): + super().__init__() + + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding. + self.include_gpu_probs_tensor = False + self.should_modify_greedy_probs_inplace = False + + def _init_sampling_tensors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + """The goal here is to reuse sampling tensors between similar decode + runs. This is possible because sampling logic does not change between + decodes of the same sequences. 
+ """ + _, vocab_size = logits.shape + + # First free any existing stored sampling tensors. + # This is necessary because some sampling tensors may + # have pinned memory. + self._sampling_tensors = None + + # Initialize new sampling tensors + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._do_min_p = do_min_p + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + """ + Single-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Pythonize sampling result & logprobs tensor + + Multi-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Defer Pythonization of sampling result & logprobs + tensor + * Encapsulate arguments required for deferred Pythonization + in the :class:`SamplerOutput` structure + + Args: + logits: (num_tokens, vocab_size). + sampling_metadata: Metadata for sampling. + """ + assert logits is not None + _, vocab_size = logits.shape + + # Prepare sampling tensors with pinned memory to avoid blocking. + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) + + # Apply presence and frequency penalties. + if do_penalties: + logits = apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) + + # Use float32 to apply temperature scaling. + # Use in-place division to avoid creating a new tensor. + logits = logits.to(torch.float) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) + + if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + if do_min_p: + logits = _apply_min_p(logits, sampling_tensors.min_ps) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # Sample the next tokens. 
+ maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=self._should_modify_greedy_probs_inplace, + ) + + if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs + assert maybe_sampled_tokens_tensor is not None + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) + else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. + on_device_tensors = None + + # Get the logprobs query results. + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + # Pythonize logprobs now (GPU -> CPU); do not defer. + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + prompt_logprobs, sample_logprobs = get_logprobs( + logprobs, sampling_metadata, maybe_deferred_sample_results) + + return _build_sampler_output( + maybe_deferred_sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) + + @property + def _should_modify_greedy_probs_inplace(self) -> bool: + """Whether or not the sampler should modify the probability distribution + of greedily-sampled tokens such that multinomial sampling would sample + the greedily-sampled token. + + In other words, if True then we set the probability of the greedily- + sampled token to 1. + + This is used by speculative decoding, which requires that the sampling + method be encoded into the probability distribution. + """ + return self.should_modify_greedy_probs_inplace + + +def _apply_min_tokens_penalty( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens + have not been generated yet + """ + # list of indices in logits that will be set to -inf + logits_to_penalize: List[Tuple[int, int]] = [] + logits_applied = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + + sample_indices = seq_group.sample_indices + logits_applied += len(sample_indices) + len( + seq_group.prompt_logprob_indices) + if not seq_group.do_sample: + continue + + start_idx = sample_indices[0] + min_tokens = sampling_params.min_tokens + token_ids_to_penalize = sampling_params.all_stop_token_ids + if min_tokens > 0 and token_ids_to_penalize: + seqs_to_penalize: List[int] = [] + for j, seq_id in enumerate(seq_ids): + seq_data = seq_group.seq_data[seq_id] + if len(seq_data.output_token_ids_array) < min_tokens: + seqs_to_penalize.append(j) + + if seqs_to_penalize: + # convert to the index into logits + seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] + # itertools.product pairs each seq index with every token id + logits_to_penalize.extend( + itertools.product(seqs_to_penalize, token_ids_to_penalize)) + + if logits_to_penalize: + # use zip and * to group indices along each dimension + # eg. 
[ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
+        logits[tuple(zip(*logits_to_penalize))] = -float("inf")
+
+    # verifies that no rows in logits were missed unexpectedly
+    assert logits_applied == logits.shape[0]
+    return logits
+
+
+def _apply_top_k_top_p(
+    logits: torch.Tensor,
+    p: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    # Apply top-p.
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
+                                                    index=logits_idx,
+                                                    src=logits_sort)
+    return logits
+
+
+def _apply_min_p(
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
+def _greedy_sample(
+    selected_seq_groups: List[SequenceGroupToSample],
+    samples: torch.Tensor,
+) -> SampleResultType:
+    """Run greedy sampling on the given samples.
+
+    Args:
+        selected_seq_groups: A list of sequence groups batched.
+        samples: (num_selected_samples,) A tensor of samples. The length of
+            samples could be smaller than selected_seq_groups if
+            seq_group.do_sample is False.
+    Returns:
+        Tuple of (next_token_ids, parent_ids). The length of the returned list
+        is the same as the length of selected_seq_groups. If the corresponding
+        seq_group has do_sample=False, the tuple contains ([], []).
+    """
+    samples_lst = samples.tolist()
+    sample_idx = 0
+    results: SampleResultType = []
+    for seq_group in selected_seq_groups:
+        if not seq_group.do_sample:
+            results.append(([], []))
+            continue
+
+        seq_ids = seq_group.seq_ids
+        num_parent_seqs = len(seq_ids)
+        assert num_parent_seqs == 1, (
+            "Greedy sampling should have only one seq.")
+        parent_ids = list(range(num_parent_seqs))
+        next_token_ids = [samples_lst[sample_idx]]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    return results
+
+
+def _random_sample(
+    selected_seq_groups: List[SequenceGroupToSample],
+    random_samples: torch.Tensor,
+) -> SampleResultType:
+    """Run random sampling on the given samples.
+
+    Args:
+        selected_seq_groups: A list of sequence groups batched.
+        random_samples: (num_selected_samples,) A tensor of samples. The
+            length of samples could be smaller than selected_seq_groups if
+            seq_group.do_sample is False.
+    Returns:
+        Tuple of (next_token_ids, parent_ids). The length of the returned list
+        is the same as the length of selected_seq_groups. If the corresponding
+        seq_group has do_sample=False, the tuple contains ([], []).
+    """
+    # Find the maximum n value of the prompt phase requests.
+ random_samples = random_samples.cpu() + sample_idx = 0 + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + num_parent_seqs = len(seq_ids) + if is_prompt: + # Prompt phase. + parent_ids = [0] * sampling_params.n + next_token_ids = random_samples[ + sample_idx, :sampling_params.n].tolist() + else: + # Generation phase. + parent_ids = list(range(num_parent_seqs)) + next_token_ids = random_samples[sample_idx:sample_idx + + num_parent_seqs, 0].tolist() + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +def _multinomial( + probs: torch.Tensor, + num_samples: int, + seq_groups: Optional[List[SequenceGroupToSample]] = None, +) -> torch.Tensor: + if num_samples > 1: + probs = probs.repeat_interleave(num_samples, dim=0) + q = torch.empty_like(probs) + if seq_groups is None: + q.exponential_() + else: + sample_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids + stride = len(seq_ids) * num_samples + assert seq_group.generator is not None + q[sample_idx:sample_idx + + stride].exponential_(generator=seq_group.generator) + sample_idx += stride + return probs.div_(q).argmax(dim=1).view(-1, num_samples) + + +def _top_k_top_p_multinomial_with_flashinfer( + probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, + num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]): + max_top_k_round = 32 + if num_samples > 1: + probs = probs.repeat_interleave(num_samples, dim=0) + top_ks = top_ks.repeat_interleave(num_samples) + top_ps = top_ps.repeat_interleave(num_samples) + batch_size = probs.shape[0] + uniform_samples = torch.empty((max_top_k_round, batch_size), + device=probs.device) + if seq_groups is None: + uniform_samples.uniform_() + else: + sample_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids + stride = len(seq_ids) * num_samples + assert seq_group.generator is not None + uniform_samples[:, sample_idx:sample_idx + + stride].uniform_(generator=seq_group.generator) + sample_idx += stride + batch_next_token_ids, success = flashinfer_top_k_top_p_sampling( + probs, + uniform_samples, + top_ks, + top_ps, + ) + if not success.all(): + warnings.warn("FlashInfer rejection sampling failed, fallback.", + stacklevel=1) + probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks) + probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps) + batch_next_token_ids = flashinfer.sampling.sampling_from_probs( + probs, uniform_samples[0]) + return batch_next_token_ids.view(-1, num_samples) + + +def get_pythonized_sample_results( + sample_result_args: SampleResultArgsType) -> SampleResultType: + '''This function consumes GPU-side sampler results and computes + Pythonized CPU-side sampler results (GPU -> CPU sync.) + + Single-step scheduling: this function is invoked at sampling-time + for immediate Pythonization. + + Multi-step scheduling: Pythonization is deferred until after multiple + GPU-side steps have been completed. 
+ + Args: + sample_result_args: GPU-side inputs to the Pythonization process + + Returns: + Pythonized sampler results + ''' + + ( + sample_metadata, + sampling_metadata, + greedy_samples, + multinomial_samples, + sample_results_dict, + ) = ( + sample_result_args.sample_metadata, + sample_result_args.sampling_metadata, + sample_result_args.greedy_samples, + sample_result_args.multinomial_samples, + sample_result_args.sample_results_dict, + ) + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, + multinomial_samples[sampling_type]) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + return [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + + +def _sample_with_torch( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> SampleReturnType: + '''Torch-oriented _sample() implementation. + + Single-step scheduling: + * Perform GPU-side sampling computation + * Immediately Pythonize sampling result + + Multi-step scheduling: + * Perform GPU-side sampling computation + * Defer Pythonization & preserve GPU-side + tensors required for Pythonization + ''' + + categorized_seq_group_ids: Dict[SamplingType, List[int]] = { + t: [] + for t in SamplingType + } + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + sampling_params = seq_group.sampling_params + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: SampleResultsDictType = {} + sample_metadata: SampleMetadataType = {} + multinomial_samples: MultinomialSamplesType = {} + greedy_samples: Optional[torch.Tensor] = None + + # Create output tensor for sampled token ids. + if include_gpu_probs_tensor: + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) + else: + sampled_token_ids_tensor = None + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. + for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups) + long_sample_indices = sample_indices.long() + if sampling_type == SamplingType.GREEDY: + greedy_samples = torch.argmax(logprobs[long_sample_indices], + dim=-1) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = greedy_samples.unsqueeze(-1) + + if modify_greedy_probs: + # If required, modify the probabilities such that sampling from + # the modified distribution would always sample the argmax + # token id. 
+ _modify_greedy_probs_inplace(logprobs, probs, + long_sample_indices, + greedy_samples) + + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + max_n_in_batch = 1 + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params + max_n_in_batch = max(max_n_in_batch, sampling_params.n) + seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else + seq_groups) + + if flashinfer_top_k_top_p_sampling is not None: + multinomial_samples[ + sampling_type] = _top_k_top_p_multinomial_with_flashinfer( + probs[long_sample_indices], + sampling_tensors.top_ks[long_sample_indices], + sampling_tensors.top_ps[long_sample_indices], + max_n_in_batch, + seq_groups_arg, + ) + else: + multinomial_samples[sampling_type] = _multinomial( + probs[long_sample_indices], + max_n_in_batch, + seq_groups=seq_groups_arg) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[long_sample_indices] = \ + multinomial_samples[sampling_type].to(torch.long) + + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + # Encapsulate arguments for computing Pythonized sampler + # results, whether deferred or otherwise. + maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=sample_metadata, + multinomial_samples=multinomial_samples, + greedy_samples=greedy_samples, + sample_results_dict=sample_results_dict) + + if not sampling_metadata.skip_sampler_cpu_output: + # GPU<->CPU sync happens here. + # This also converts the sampler output to a Python object. + # Return Pythonized sampler result & sampled token ids + return get_pythonized_sample_results( + maybe_deferred_args), sampled_token_ids_tensor + else: + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> SampleReturnType: + """ + Args: + probs: (num_query_tokens_in_batch, num_vocab) + logprobs: (num_query_tokens_in_batch, num_vocab) + sampling_metadata: The metadata for a batch for sampling. + sampling_tensors: Tensors that include sampling related metadata. + + Returns: + (next_token_ids, parent_seq_ids) for each seq group in a batch. + If sampling is skipped, it returns ([], []) + sampled_token_ids_tensor: A tensor of sampled token ids. + """ + return _sample_with_torch( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=include_gpu_probs_tensor, + modify_greedy_probs=modify_greedy_probs, + ) + + +def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + """ + This function calculates the ranks of the chosen tokens in a logprob tensor. + + Args: + x (torch.Tensor): 2D logprob tensor of shape (N, M) + where N is the no. of tokens and M is the vocab dim. + indices (torch.Tensor): List of chosen token indices. + + Returns: + torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. + Each element in the returned tensor represents the rank + of the chosen token in the input logprob tensor. 
+    """
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    result = (x > vals[:, None])
+    del vals
+    return result.sum(1).add_(1)
+
+
+def get_logprobs(
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sample_results: SampleResultType,
+) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
+    """Return sample logprobs and prompt logprobs.
+
+    The logic consists of 3 parts.
+    - Select indices to compute logprob from, ranks of token ids, and
+      the top k token ids from logprobs.
+    - Compute prompt logprobs if required.
+    - Compute sample logprobs if required.
+
+    Args:
+        logprobs: (num_query_tokens_across_batch, num_vocab). Each query
+            token's logprob per vocab. Sequence groups' query tokens are
+            batched in a single flattened tensor. For example, assuming there
+            are N seq groups, it is sorted by prefill tokens for seq_group_1
+            (if prompt logprob is enabled), decode tokens for seq_group_1 (if
+            sampling is required), prefill tokens for seq_group_2, ...
+        sampling_metadata: The sampling metadata.
+        sample_results: (num_seq_groups) The tuple of (next_token_ids,
+            parent_ids) for each sequence group. When beam search is enabled,
+            sample_results can contain a different number of seq_ids from
+            sampling_metadata.seq_groups. This is because beam search creates
+            2 * BEAM_WIDTH samples (whereas there are only up to BEAM_WIDTH
+            seq_ids).
+
+    Returns:
+        A tuple of prompt and sample logprobs per sequence group in a batch.
+    """
+    # The index of query token to calculate logprobs. It includes both
+    # prompt and sample logprob indices.
+    query_indices: List[int] = []
+    # The next token ids to get the logprob value from.
+    next_token_ids: List[int] = []
+    # The largest requested number of logprobs. Top-k logprobs are computed
+    # with this value across the batch. If no request asks for logprobs, it
+    # remains -1.
+    largest_num_logprobs = -1
+
+    # Select indices to compute logprob from, ranks of token ids, and the top
+    # k token ids from logprobs.
+    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
+                                          sample_results):
+        sampling_params = seq_group.sampling_params
+
+        # Update indices and tokens for prompt logprobs.
+        if (seq_group.is_prompt
+                and sampling_params.prompt_logprobs is not None):
+            largest_num_logprobs = max(largest_num_logprobs,
+                                       sampling_params.prompt_logprobs)
+            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
+            query_indices.extend(seq_group.prompt_logprob_indices)
+            next_token_ids.extend(next_prompt_tokens)
+
+        # Update indices and next tokens for sample logprob.
+        if seq_group.do_sample:
+            token_ids, parent_seq_ids = sample_result
+            # NOTE: We cannot directly use sample_indices because
+            # sample_indices only contain parent seq_ids of a previous step.
+            # The current step may have a different number of seq_ids, and
+            # we can obtain it from `sample_result[1]`.
+ query_idx = seq_group.sample_indices[0] + query_indices.extend( + [query_idx + parent_id for parent_id in parent_seq_ids]) + next_token_ids.extend(token_ids) + + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + + assert len(next_token_ids) == len(query_indices) + + if len(query_indices) == 0: + empty_sampled_logprob: SampleLogprobs = [] + empty_prompt_logprob: Optional[PromptLogprobs] = None + num_seq_groups = len(sampling_metadata.seq_groups) + return [empty_prompt_logprob + ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups + + selected_logprobs, ranks = None, None + top_logprobs, top_token_ids = None, None + + # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can + # skip the whole logprob calculation. + if largest_num_logprobs >= 0: + query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) + next_token_ids_gpu = torch.tensor(next_token_ids, + device=logprobs.device) + + # (num_selected_query_tokens, num_logprobs). Note that query_indices can + # contain duplicates if beam search is enabled. + selected_logprobs = logprobs[[ + query_indices_gpu, + next_token_ids_gpu, + ]] + ranks = _get_ranks( + logprobs[query_indices_gpu], + next_token_ids_gpu, + ) + assert selected_logprobs.shape[0] == ranks.shape[0] + + # We need to compute top k only if there exists logprobs > 0. + if largest_num_logprobs > 0: + # Logprobs of topk tokens for a batch of sequence groups. + # (num_query_tokens_across_batch). + top_logprobs, top_token_ids = torch.topk(logprobs, + largest_num_logprobs, + dim=-1) + top_logprobs = top_logprobs.to('cpu') + top_token_ids = top_token_ids.to('cpu') + + selected_logprobs = selected_logprobs.to('cpu') + ranks = ranks.to('cpu') + + # Find prompt/sample logprobs. + prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: List[SampleLogprobs] = [] + top_logprob_idx = 0 + selected_logprobs_idx = 0 + + for seq_group, sample_result in zip(sampling_metadata.seq_groups, + sample_results): + (prompt_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_prompt_logprob_if_needed( + seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, + selected_logprobs_idx, top_logprob_idx) + prompt_logprobs_per_seq_group.append(prompt_logprobs) + + (sampled_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_sampled_logprob_if_needed( + seq_group, sample_result, selected_logprobs, ranks, top_token_ids, + top_logprobs, selected_logprobs_idx, top_logprob_idx) + sample_logprobs_per_seq_group.append(sampled_logprobs) + + return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group + + +def _get_prompt_logprob_if_needed( + seq_group: SequenceGroupToSample, + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the prompt logprob from a sequence group if needed.""" + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + + # Find prompt logprobs + prompt_logprobs: Optional[PromptLogprobs] = None + if is_prompt and sampling_params.prompt_logprobs is not None: + prompt_logprobs = [] + num_logprobs = sampling_params.prompt_logprobs + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + # Pre-select indexes and create a list. It is faster than calling .item + # repetitively. 
+ selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + + for idx, token_id in enumerate(next_prompt_tokens): + # Calculate the prompt logprob of the real prompt tokens. + # {token_id: (logprob, rank_from_vocab)} + prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { + token_id: (selected_logprob_items[idx], rank_items[idx]) + } + + # Add top K prompt logprobs along with its rank. + if num_logprobs > 0: + top_ids = top_token_ids[ + top_logprob_idx, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + prompt_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) + prompt_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in prompt_logprobs_dict.items() + }) + # + 1 to go to the next prompt token. + top_logprob_idx += 1 + + # + len(next_prompt_tokens) to go to the next prompt. + selected_logprobs_idx += len(next_prompt_tokens) + return prompt_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _get_sampled_logprob_if_needed( + seq_group: SequenceGroupToSample, + sample_result: Tuple[List[int], List[int]], + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the sample logprob if needed.""" + seq_ids = seq_group.seq_ids + num_logprobs = seq_group.sampling_params.logprobs + sampled_logprobs: SampleLogprobs = [] + next_token_ids, parent_seq_ids = sample_result + + if seq_group.do_sample: + assert len(next_token_ids) > 0 + if num_logprobs is None: + for next_token_id in next_token_ids: + # Use a dummy logprob + sampled_logprobs.append({next_token_id: Logprob(inf)}) + else: + # Pre-select items from tensor. tolist() is faster than repetitive + # `.item()` calls. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + for idx, (next_token_id, parent_id) in enumerate( + zip(next_token_ids, parent_seq_ids)): + # Get the logprob of a sampled token. + sampled_logprobs_dict = { + next_token_id: + (selected_logprob_items[idx], rank_items[idx]) + } + if num_logprobs is not None and num_logprobs > 0: + # Get top K logprobs. + top_ids = top_token_ids[top_logprob_idx + + parent_id, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx + parent_id, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + sampled_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip( + top_ids, top_probs, top_ranks) + }) + + sampled_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in + sampled_logprobs_dict.items() + }) + + # NOTE: This part of code is not intuitive. `selected_logprobs` include + # logprobs for the current step, which has len(next_token_ids) tokens + # per sequence group. 
`logprobs` includes logprobs from the previous + # steps, which has len(seq_ids) tokens per sequence group. + + # Iterate to the next sequence group in a batch. + selected_logprobs_idx += len(next_token_ids) + # Iterate to the next sequence group in a batch. + top_logprob_idx += len(seq_ids) + return sampled_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, + sample_indices: torch.Tensor, + greedy_samples: torch.Tensor) -> None: + """Modify the probability distributions of the greedily-sampled tokens such + that each sampled token has a "probability" of 1.0. This is required by + speculative decoding, which depends on the sampling method being encoded + within the probability distribution for correctness. + + # Why do we only need to do this for greedy sampling? + + vLLM's sampler performs the following steps for greedy or multinomial + (random) sampling: + 1. Get logits from model. + 2. Modify logits according to per-sequence sampling parameters. + - Multiply by temperature, top-k and top-p masking, penalize tokens + according to their frequency, etc. + 3. Sample a token. + - Random sampling simply samples from the modified probability + distribution. + - Greedy sampling performs `argmax` to obtain the token with the + highest likelihood. + + Ignoring greedy sampling for a moment, we find that the computed probability + distribution has the following property: we can sample from it independently + and find that the token sampled by the Sampler has a frequency corresponding + to how often we see it in our sampling. In other words, for tokens sampled + with vLLM's random SamplingType, the computed probability distribution + encodes the sampling methodology completely. + + Greedy sampling does not normally have this property. vLLM modifies logits + according to sampling params, then performs `argmax`, then returns the + sampled token and the computed probability distribution. If we sample from + the distribution, we'll find the likelihood of the greedily-sampled token + is not always 1.0. + + Since lossless speculative decoding requires that the sampling methodology + be encoded within the probability distribution, we are motivated to modify + the probability distribution such that the sampled token has probability 1 + when speculative decoding is used. + + NOTE: Alternatively, we could use an extremely low temperature to achieve + greedy sampling using multinomial computation and unite the codepaths. This + has implications on the overall design of the sampler, e.g. how to record + accurate logprobs for the user, so this improvement is deferred to later. + """ + # NOTE: logprobs are not modified so they can be returned to the user. + probs[sample_indices, :] = 0 + probs[sample_indices, greedy_samples] = 1.0 + + +def _build_sampler_output( + maybe_deferred_sample_results: MaybeDeferredSampleResultType, + sampling_metadata: SamplingMetadata, + prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], + sample_logprobs: Optional[List[SampleLogprobs]], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]], + skip_sampler_cpu_output: bool = False, +) -> SamplerOutput: + """Construct Python objects with the output of sampling. + + Args: + on_device_tensors: Tuple containing on-device tensors with the + probabilities used in sampling and the sampled token ids. This + allows post-processing without copies to CPU/serialization, e.g. in + speculative decoding rejection sampling. 
+ """ + sampler_output: List[CompletionSequenceGroupOutput] = [] + + if skip_sampler_cpu_output: + assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) + deferred_sample_results_args = maybe_deferred_sample_results + else: + assert prompt_logprobs is not None + assert sample_logprobs is not None + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + assert len(sampling_metadata.seq_groups) \ + == len(maybe_deferred_sample_results) \ + == len(prompt_logprobs) \ + == len(sample_logprobs) + deferred_sample_results_args = None + + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + maybe_deferred_sample_results, + prompt_logprobs, sample_logprobs): + seq_ids = seq_group.seq_ids + next_token_ids, parent_ids = sample_result + seq_outputs: List[SequenceOutput] = [] + for parent_id, next_token_id, logprobs in zip( + parent_ids, next_token_ids, group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + logprobs)) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, + group_prompt_logprobs)) + + # If not specified, store None values in SamplerOutput. + if on_device_tensors is not None: + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors + else: + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) + + return SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, + deferred_sample_results_args=deferred_sample_results_args) + + +def _get_next_prompt_tokens( + seq_group: SequenceGroupToSample) -> tuple[int, ...]: + """Get a list of next prompt tokens to compute logprob from a + given sequence group. + + It is used to compute prompt logprob. Imagine you have logprob for each + query token. Query token needs to know the next prompt token id to compute + prompt logprob. This is a helper to obtain next prompt token ids. + + This API has to be used only when the caller knows seq_group is in prefill + stage. + + Returns: + A list of next prompt tokens to compute logprob. + """ + assert seq_group.is_prompt, ( + "Caller should ensure the sequence group is in a prefill stage.") + seq_ids = seq_group.seq_ids + query_len = seq_group.query_len + assert query_len is not None + # prompt has only 1 seq id. + assert len(seq_ids) == 1 + seq_data = seq_group.seq_data[seq_ids[0]] + computed_len = seq_data.get_num_computed_tokens() + prompt_tokens = seq_data.prompt_token_ids + # +1 because we are looking for a next prompt token. + next_token_index_start = computed_len + 1 + next_token_index_end = min(computed_len + query_len + 1, + len(prompt_tokens)) + next_prompt_tokens = prompt_tokens[ + next_token_index_start:next_token_index_end] + return next_prompt_tokens diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py new file mode 100644 index 0000000..54fd43f --- /dev/null +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from typing import Dict, Optional, Union + +import torch +import torch.jit +import torch.nn as nn + +from vllm.platforms import current_platform + + +class SpecDecodeBaseSampler(nn.Module): + """Base class for samplers used for Speculative Decoding verification + step. 
+ """ + + def __init__(self, strict_mode: bool = False): + """Base class constructor. + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. + self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, device: Union[int, str]) -> None: + assert self.num_accepted_tokens is None + if isinstance(device, int): + device = f"{current_platform.device_type}:{device}" + elif not isinstance(device, str): + raise ValueError(f"Device must be int or str, get {type(device)}") + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + def init_tensors(self, + device: Union[int, str], + device_type: Union[torch.device, str] = 'cuda') -> None: + assert self.num_accepted_tokens is None + if isinstance(device_type, torch.device): + device_type = device_type.type + if isinstance(device, int): + device = f"{device_type}:{device}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + substitute_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via sampling, all subsequent token ids are + set to -1 for the sequence. + + Args: + accepted: A boolean tensor indicating if the corresponding + draft token in draft_token_ids should be accepted or not. + substitute_token_ids: A tensor of token_ids that can be used + as substitutes for the draft token ids if the proposed token + is rejected. + draft_token_ids: A tensor of token ids speculated by the + draft model. + bonus_token_ids: Token ids to use as the bonus token if + all the draft tokens are accepted. + Returns: + A tensor containing the accepted token ids. The shape of the + tensor is [batch_size, k + num_bonus_tokens] + """ + batch_size, k = substitute_token_ids.shape + bonus_token_ids = bonus_token_ids.squeeze(-1) + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. 
+ output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # Fill the recovered token ids. + output.mul_(~after_false_mask).add_( + substitute_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_input( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + self._raise_if_incorrect_shape(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_incorrect_dtype(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_inconsistent_device(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1], + draft_token_ids, bonus_token_ids) + + def _raise_if_incorrect_shape( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_with_bonus_probs.shape + + # Does not count the extra token + num_target_probs -= 1 + + # validate the shape of draft token ids. + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + assert draft_token_ids_batch_size == target_batch_size + assert num_draft_token_ids == num_target_probs + + # validate the shape of bonus token ids + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + # validate the shape of draft probs if it is set + if draft_probs is not None: + (draft_batch_size, num_draft_probs, + draft_vocab_size) = draft_probs.shape + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + def _raise_if_incorrect_dtype( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + assert target_with_bonus_probs.dtype == self.probs_dtype + assert draft_token_ids.dtype == self.token_id_dtype + assert bonus_token_ids.dtype == self.token_id_dtype + if draft_probs is not None: + assert draft_probs.dtype == self.probs_dtype + + def _raise_if_inconsistent_device( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + devices = [ + t.device for t in [ + target_with_bonus_probs, bonus_token_ids, draft_probs, + draft_token_ids + ] if t is not None + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert 
torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + + +class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): + """Base class for samplers used for Speculative Decoding verification + step which are deterministic. + """ + + @abstractmethod + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): + """Base class for samplers used for Speculative Decoding verification + step which are stochastic + """ + + @abstractmethod + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + seeded_seqs: Optional[Dict[int, torch.Generator]] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py new file mode 100644 index 0000000..95362c2 --- /dev/null +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.jit + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeDeterministicBaseSampler) + + +class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): + """Apply typical acceptance sampling as described in section 3.3.1 in + "MEDUSA: Simple LLM Inference Acceleration Framework with + Multiple Decoding Heads" + https://arxiv.org/pdf/2401.10774 + """ + + def __init__( + self, + posterior_threshold: float, + posterior_alpha: float, + strict_mode: bool = False, + ): + """Create a Typical Acceptance Sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + posterior_threshold : A threshold value that sets a lower bound + on the posterior probability of a token in target model for it + to be accepted. + posterior_alpha : A scaling factor for the entropy-based + threshold in typical acceptance sampling. + """ + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha + super().__init__(strict_mode=strict_mode) + + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using typical acceptance sampling. This accepts + or rejects tokens proposed by the draft model using the probability + of each token according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one token will be emitted. + + In the case where all draft tokens are accepted, the bonus token will be + accepted. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: This parameter is unused by the acceptance sampler. + + draft_token_ids: The token ids that were sampled from the draft + probabilities. 
+ shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids) + target_probs = target_with_bonus_probs[:, :-1] + accepted = self._evaluate_accepted_tokens(target_probs, + draft_token_ids) + recovered_token_ids = self._get_recovered_token_ids(target_probs) + output_token_ids = self._create_output(accepted, recovered_token_ids, + draft_token_ids, + bonus_token_ids) + return output_token_ids + + def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): + r""" + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. + + Parameters: + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) representing + the probabilities of each token in the vocabulary for each + position in the proposed sequence. This is the distribution + generated by the target model. + draft_token_ids : torch.Tensor + A tensor of shape (batch_size, k) representing the proposed + token ids. + + A draft token_id x_{n+k} is accepted if it satisfies the + following condition + + .. math:: + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + + where :math:`p_{\text{original}}` corresponds to target_probs + and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. + + Returns: + ------- + torch.Tensor + A boolean tensor of shape (batch_size, k) where each element + indicates whether the corresponding draft token has been accepted + or rejected. True indicates acceptance and false indicates + rejection. + + """ + device = target_probs.device + candidates_prob = torch.gather( + target_probs, dim=-1, + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + # A small constant added to prevent computing the logarithm of zero, + # which can lead to undefined values. + epsilon = 1e-5 + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + epsilon), dim=-1) + threshold = torch.minimum( + torch.ones_like(posterior_entropy, device=device) * + self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, + ) + accepted_mask = candidates_prob > threshold + return accepted_mask + + def _get_recovered_token_ids(self, target_probs): + """ + The recovered token ids will fill the first unmatched token + by the target token. + + Parameters + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) containing + the target probability distribution + + Returns + ------- + torch.Tensor + A tensor of shape (batch_size, k) with the recovered token + ids which are selected from target probs. 
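+        For example (illustrative values):
+
+        >>> target_probs = torch.tensor([[[0.1, 0.7, 0.2]]])
+        >>> torch.argmax(target_probs, dim=-1)
+        tensor([[1]])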
+ """ + max_indices = torch.argmax(target_probs, dim=-1) + + return max_indices diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py new file mode 100644 index 0000000..a9ef973 --- /dev/null +++ b/vllm/model_executor/layers/utils.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utility methods for model layers.""" +from typing import Tuple + +import torch + + +def get_token_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + """ + Applies penalties in place to the logits tensor + logits : The input logits tensor of shape [num_seqs, vocab_size] + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID + in the vocabulary. + output_tokens_tensor: The output tokens tensor. + presence_penalties: The presence penalties of shape (num_seqs, ) + frequency_penalties: The frequency penalties of shape (num_seqs, ) + repetition_penalties: The repetition penalties of shape (num_seqs, ) + """ + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( + 1, vocab_size) + logits[logits > 0] /= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits > 0] + logits[logits <= 0] *= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits <= 0] + # We follow the definition in OpenAI API. 
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask + return logits diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py new file mode 100644 index 0000000..b90792f --- /dev/null +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -0,0 +1,489 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F +import ixformer.inference.functions as IXF +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.parameter import BasevLLMParameter +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +class UnquantizedEmbeddingMethod(QuantizeMethodBase): + """Unquantized method for embeddings.""" + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for embedding layer.""" + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return IXF.linear(x, layer.weight, bias) + + def embedding(self, layer: torch.nn.Module, + input_: torch.Tensor) -> torch.Tensor: + return F.embedding(input_, layer.weight) + + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + +def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, + rank: int, + offset: int = 0) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f + offset, index_l + offset + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, + rank: int, + world_size: int, + offset: int = 0) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, + offset=offset) + + +@dataclass +class VocabParallelEmbeddingShardIndices: + """Indices for a shard of a vocab parallel embedding.""" + padded_org_vocab_start_index: int + padded_org_vocab_end_index: int + padded_added_vocab_start_index: int + padded_added_vocab_end_index: int + + org_vocab_start_index: int + org_vocab_end_index: int + added_vocab_start_index: int + added_vocab_end_index: int + + @property + def num_org_elements(self) -> int: + return self.org_vocab_end_index - self.org_vocab_start_index + + @property + def num_added_elements(self) -> int: + return 
self.added_vocab_end_index - self.added_vocab_start_index
+
+    @property
+    def num_org_elements_padded(self) -> int:
+        return (self.padded_org_vocab_end_index -
+                self.padded_org_vocab_start_index)
+
+    @property
+    def num_added_elements_padded(self) -> int:
+        return (self.padded_added_vocab_end_index -
+                self.padded_added_vocab_start_index)
+
+    @property
+    def num_org_vocab_padding(self) -> int:
+        return self.num_org_elements_padded - self.num_org_elements
+
+    @property
+    def num_added_vocab_padding(self) -> int:
+        return self.num_added_elements_padded - self.num_added_elements
+
+    @property
+    def num_elements_padded(self) -> int:
+        return self.num_org_elements_padded + self.num_added_elements_padded
+
+    def __post_init__(self):
+        # sanity checks
+        assert (self.padded_org_vocab_start_index
+                <= self.padded_org_vocab_end_index)
+        assert (self.padded_added_vocab_start_index
+                <= self.padded_added_vocab_end_index)
+
+        assert self.org_vocab_start_index <= self.org_vocab_end_index
+        assert self.added_vocab_start_index <= self.added_vocab_end_index
+
+        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
+        assert (self.added_vocab_start_index
+                <= self.padded_added_vocab_start_index)
+        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
+        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
+
+        assert self.num_org_elements <= self.num_org_elements_padded
+        assert self.num_added_elements <= self.num_added_elements_padded
+
+
+# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
+def get_masked_input_and_mask(
+        input_: torch.Tensor, org_vocab_start_index: int,
+        org_vocab_end_index: int, num_org_vocab_padding: int,
+        added_vocab_start_index: int,
+        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    # torch.compile will fuse all of the pointwise ops below
+    # into a single kernel, making it very fast
+    org_vocab_mask = (input_ >= org_vocab_start_index) & (
+        input_ < org_vocab_end_index)
+    added_vocab_mask = (input_ >= added_vocab_start_index) & (
+        input_ < added_vocab_end_index)
+    added_offset = added_vocab_start_index - (
+        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
+    # valid_offset = (org_vocab_start_index *
+    #                 org_vocab_mask) + (added_offset * added_vocab_mask)
+    # torch.compile with the commented-out line above produces wrong results
+    # on this platform: it appears to keep "int * bool_tensor" as a 0/1
+    # result instead of promoting it to int, so the masks are cast to int
+    # explicitly below.
+    valid_offset = (org_vocab_start_index *
+                    org_vocab_mask.int()) + (added_offset * added_vocab_mask.int())
+    vocab_mask = org_vocab_mask | added_vocab_mask
+    input_ = vocab_mask * (input_ - valid_offset)
+    return input_, ~vocab_mask
+
+
+class VocabParallelEmbedding(torch.nn.Module):
+    """Embedding parallelized in the vocabulary dimension.
+
+    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
+    make sure it is divisible by the number of model parallel GPUs.
+
+    In order to support various loading methods, we ensure that LoRA-added
+    embeddings are always at the end of TP-sharded tensors. In other words,
+    we shard base embeddings and LoRA embeddings separately (both padded),
+    and place them in the same tensor.
+    In this example, we will have the original vocab size = 1010,
+    added vocab size = 16 and padding to 64. Therefore, the total
+    vocab size with padding will be 1088 (because we first pad 1010 to
+    1024, add 16, and then pad to 1088).
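+    Concretely: ceil(1010 / 64) * 64 = 1024, then 1024 + 16 = 1040, and
+    ceil(1040 / 64) * 64 = 1088.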
+ Therefore, the tensor format looks like the following: + TP1, rank 0 (no sharding): + |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >| + corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 | + + TP2, rank 0: + |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >| + corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 | + TP2, rank 1: + |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| + corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + quant_config: quant config for the layer + prefix: full name of the layer in the state dict + """ # noqa: E501 + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + # Keep the input dimensions. + tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_embeddings = num_embeddings + self.padding_size = padding_size + self.org_vocab_size = org_num_embeddings or num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + self.embedding_dim = embedding_dim + + quant_method = None + if quant_config is not None: + quant_method = quant_config.get_quant_method(self, prefix=prefix) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. + is_embedding_layer = type(self) is VocabParallelEmbedding + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method)) + if is_embedding_layer and not quant_method_implements_embedding: + raise NotImplementedError( + f"The class {type(quant_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod.") + + self.quant_method: QuantizeMethodBase = quant_method + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + # Divide the weight matrix along the vocaburaly dimension. 
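+        # Each tensor-parallel rank owns num_embeddings_padded // tp_size
+        # rows; shard_indices records which original and added vocab ids
+        # fall inside this rank's slice.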
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) + + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) + + @classmethod + def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, + vocab_size: int, org_vocab_size: int, tp_rank: int, + tp_size: int) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = ( + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, + tp_size)) + padded_added_vocab_start_index, padded_added_vocab_end_index = ( + vocab_range_from_global_vocab_size(num_added_embeddings_padded, + tp_rank, + tp_size, + offset=org_vocab_size)) + # remove padding + org_vocab_start_index = min(padded_org_vocab_start_index, + org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, + vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, padded_org_vocab_end_index, + padded_added_vocab_start_index, padded_added_vocab_end_index, + org_vocab_start_index, org_vocab_end_index, + added_vocab_start_index, added_vocab_end_index) + + def get_sharded_to_full_mapping(self) -> Optional[List[int]]: + """Get a mapping that can be used to reindex the gathered + logits for sampling. + + During sampling, we gather logits from all ranks. The relationship + of index->token_id will follow the same format as outlined in the class + docstring. However, after the gather, we want to reindex the final + logits tensor to map index->token_id one-to-one (the index is always + equal the token_id it corresponds to). The indices returned by this + method allow us to do that. 
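+        For example, with tp_size = 2 the gathered logits are laid out as
+        [rank0 base, rank0 base padding, rank0 added, rank0 added padding,
+        rank1 base, ...]; indexing with the returned list reorders them to
+        [all base vocab ids, all added vocab ids, all padding].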
+ """ + if self.tp_size < 2: + return None + + base_embeddings: List[int] = [] + added_embeddings: List[int] = [] + padding: List[int] = [] + for tp_rank in range(self.tp_size): + shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + range_start = self.num_embeddings_per_partition * tp_rank + range_end = self.num_embeddings_per_partition * (tp_rank + 1) + base_embeddings.extend( + range(range_start, + range_start + shard_indices.num_org_elements)) + padding.extend( + range(range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded)) + added_embeddings.extend( + range( + range_start + shard_indices.num_org_elements_padded, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements)) + padding.extend( + range( + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded)) + assert (range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded == range_end) + ret = base_embeddings + added_embeddings + padding + assert len(ret) == self.num_embeddings_padded + return ret + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + # If the parameter is a gguf weight, then load it directly. + if getattr(param, "is_gguf_weight_type", None): + param.data.copy_(loaded_weight) + param.weight_type = loaded_weight.item() + return + elif isinstance(param, UninitializedParameter): + shape = list(loaded_weight.shape) + if output_dim is not None: + shape[output_dim] = self.num_embeddings_per_partition + param.materialize(tuple(shape), dtype=loaded_weight.dtype) + + # If parameter does not have output dim, then it should + # be copied onto all gpus (e.g. g_idx for act_order gptq). + if output_dim is None: + assert param.data.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + return + + # Shard indexes for loading the weight + start_idx = self.shard_indices.org_vocab_start_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. + if packed_dim is not None and packed_dim == output_dim: + packed_factor = param.packed_factor if isinstance( + param, BasevLLMParameter) else param.pack_factor + assert loaded_weight.shape[output_dim] == (self.org_vocab_size // + param.packed_factor) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size + + # Copy the data. Select chunk corresponding to current shard. + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + if current_platform.is_hpu(): + # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here, + # so we're using a workaround. Remove this when fixed in + # HPU PT bridge. + padded_weight = torch.cat([ + loaded_weight, + torch.zeros(param.shape[0] - loaded_weight.shape[0], + *loaded_weight.shape[1:]) + ]) + param.data.copy_(padded_weight) + else: + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0]:].data.fill_(0) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. 
+ masked_input, input_mask = get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, + masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output + + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f', num_embeddings_padded={self.num_embeddings_padded}' + s += f', tp_size={self.tp_size}' + return s + + +class ParallelLMHead(VocabParallelEmbedding): + """Parallelized LM head. + + Output logits weight matrices used in the Sampler. The weight and bias + tensors are padded to make sure they are divisible by the number of + model parallel GPUs. + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + bias: whether to use bias. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size, quant_config, + prefix) + self.quant_config = quant_config + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def tie_weights(self, embed_tokens: VocabParallelEmbedding): + """Tie the weights with word embeddings.""" + # GGUF quantized embed_tokens. 
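+        # For GGUF-quantized embeddings the embedding module itself is
+        # returned and reused in place of the LM head rather than aliasing
+        # its packed weight tensor.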
+ if self.quant_config and self.quant_config.get_name() == "gguf": + return embed_tokens + else: + self.weight = embed_tokens.weight + return self + + def forward(self, input_): + del input_ + raise RuntimeError("LMHead's weights should be used in the sampler.") diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py new file mode 100644 index 0000000..9048c70 --- /dev/null +++ b/vllm/model_executor/model_loader/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 + +from torch import nn + +from vllm.config import VllmConfig +from vllm.model_executor.model_loader.loader import (BaseModelLoader, + get_model_loader) +from vllm.model_executor.model_loader.utils import ( + get_architecture_class_name, get_model_architecture) + + +def get_model(*, vllm_config: VllmConfig) -> nn.Module: + loader = get_model_loader(vllm_config.load_config) + return loader.load_model(vllm_config=vllm_config) + + +__all__ = [ + "get_model", "get_model_loader", "BaseModelLoader", + "get_architecture_class_name", "get_model_architecture" +] diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py new file mode 100644 index 0000000..4314d1a --- /dev/null +++ b/vllm/model_executor/model_loader/loader.py @@ -0,0 +1,1564 @@ +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: SIM117 +import collections +import copy +import dataclasses +import fnmatch +import glob +import inspect +import itertools +import math +import os +import time +import warnings +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, + Tuple, cast) + +import gguf +import huggingface_hub +import numpy as np +import torch +from huggingface_hub import HfApi +from torch import nn +from transformers import AutoModelForCausalLM +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.attention import Attention +from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, + VllmConfig, set_current_vllm_config) +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, + serialize_vllm_model, tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (ParamMapping, + configure_quant_config, + get_model_architecture, + set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, get_gguf_extra_tensor_names, + get_lock, gguf_quant_weights_iterator, initialize_dummy_weights, + np_cache_weights_iterator, pt_weights_iterator, + runai_safetensors_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.utils import is_s3 +from vllm.utils import 
is_pin_memory_available +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +logger = init_logger(__name__) + + +def _initialize_model( + vllm_config: VllmConfig, + *, + prefix: str = "", +) -> nn.Module: + """Initialize a model with the given configurations.""" + model_config = vllm_config.model_config + model_class, _ = get_model_architecture(model_config) + + if vllm_config.quant_config is not None: + configure_quant_config(vllm_config.quant_config, model_class) + + signatures = inspect.signature(model_class.__init__) + all_params = [param.name for param in signatures.parameters.values()] + if "vllm_config" in all_params and "prefix" in all_params: + # new-style model class + with set_current_vllm_config(vllm_config, check_compile=True): + return model_class(vllm_config=vllm_config, prefix=prefix) + + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " + "input arguments. Possibly you have an old-style model class" + " registered from out of tree and it is used for new vLLM version. 
" + "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " + "for the design and update the model class accordingly.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + logger.warning( + "Trying to guess the arguments for old-style model class %s", + model_class, + ) + # try to be compatible with old-style model class + kwargs = {} + if "prefix" in all_params: + kwargs["prefix"] = prefix + if "config" in all_params: + kwargs["config"] = model_config.hf_config + if "cache_config" in all_params: + kwargs["cache_config"] = vllm_config.cache_config + if "quant_config" in all_params: + kwargs["quant_config"] = vllm_config.quant_config + if "lora_config" in all_params: + kwargs["lora_config"] = vllm_config.lora_config + if "scheduler_config" in all_params: + kwargs["scheduler_config"] = vllm_config.scheduler_config + with set_current_vllm_config(vllm_config, check_compile=True): + return model_class(**kwargs) + + +def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig, + target_device: torch.device) -> None: + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. + for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: + """Load a model with the given configurations.""" + raise NotImplementedError + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + allow_patterns_overrides: Optional[list[str]] = None + """If defined, weights will load exclusively using these patterns.""" + + counter_before_loading_weights: float = 0.0 + counter_after_loading_weights: float = 0.0 + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> 
Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model, self.load_config.download_dir): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants. + HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights( + self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], + ) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif (load_format == LoadFormat.SAFETENSORS + or load_format == LoadFormat.FASTSAFETENSORS): + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt, + source.allow_patterns_overrides) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + source.model_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + elif use_safetensors: + if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: + weights_iterator = fastsafetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + + if current_platform.is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. + import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + + elif current_platform.is_hpu(): + import habana_frameworks.torch.core as htcore + + def _hpu_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + htcore.mark_step() + + weights_iterator = _hpu_weights_iterator(weights_iterator) + + if self.counter_before_loading_weights == 0.0: + self.counter_before_loading_weights = time.perf_counter() + # Apply the prefix. 
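+        # e.g. a source with prefix="language_model." (illustrative value)
+        # would yield names like "language_model.lm_head.weight".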
+ return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", + None), + ) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast( + Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ()), + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, + model_config.revision, + fall_back_to_pt=True, + allow_patterns_overrides=None) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model(vllm_config=vllm_config) + + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self._get_all_weights(model_config, model)) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. + if model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + + _process_weights_after_loading(model, model_config, target_device) + + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model(vllm_config=vllm_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. 
+ initialize_dummy_weights(model) + + _process_weights_after_loading(model, model_config, target_device) + return model.eval() + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self, ) -> Generator[Tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_serialized_cpu( + self, + vllm_config: VllmConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer to the CPU. + + This is only necessary when the model isn't vLLM-tensorized (see + examples/other/tensorize_vllm_model.py) This should still + be faster than default HuggingFace loading, but will be slower than + loading a vLLM-tensorized model. + """ + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def _load_model_serialized( + self, + vllm_config: VllmConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer. + + Expects a vLLM-tensorized model. 
See the + examples/other/tensorize_vllm_model.py example script + for serializing vLLM models.""" + + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model_class = get_model_architecture(model_config)[0] + + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + + model = load_with_tensorizer(tensorizer_config, + vllm_config=vllm_config) + return model.eval() + + def download_model(self, model_config: ModelConfig) -> None: + self.tensorizer_config.verify_with_model_config(model_config) + + with self.tensorizer_config.open_stream(): + pass + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self._verify_config(model_config, parallel_config) + + if parallel_config.tensor_parallel_size > 1: + from vllm.distributed import get_tensor_model_parallel_rank + + self.tensorizer_config.tensorizer_uri = ( + self.tensorizer_config.tensorizer_uri % + get_tensor_model_parallel_rank()) + + if is_vllm_tensorized(self.tensorizer_config): + return self._load_model_serialized(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config) + + @staticmethod + def save_model( + model: torch.nn.Module, + tensorizer_config: TensorizerConfig, + ) -> None: + serialize_vllm_model( + model=model, + tensorizer_config=tensorizer_config, + ) + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/offline_inference/save_sharded_state.py` for creating a sharded + checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( + collections.defaultdict(list)) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result: Dict[str, torch.Tensor] = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. 
+ if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + from safetensors.torch import safe_open + + from vllm.distributed import get_tensor_model_parallel_rank + + local_model_path = self._prepare_weights(model_config.model, + model_config.revision) + + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model(vllm_config=vllm_config) + _process_weights_after_loading(model, model_config, + target_device) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for path in filepaths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
+ param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + possible_config_file_names = ["adapter_config.json"] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # Save the module names without sharding. + self.unsharded_weights_modules: List[str] = [] + # Save the module names that are sharded by column. + self.column_sharded_weights_modules: List[str] = [] + # Store all module names (from transformers) that support + # BNB quantization. + self.target_modules: List[str] = [] + # mapping weight names from transformers to vllm. + self.weight_mapper: Callable = lambda name: name + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None, + ) -> Tuple[str, List[str], str]: + """Retrieve weight files. Download the files if necessary. 
+ + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return model_name_or_path, weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return hf_folder, glob.glob( + os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + use_safetensors = matched_pattern == "*.safetensors" + is_local = os.path.isdir(model_name_or_path) + index_file = SAFE_WEIGHTS_INDEX_NAME + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, use_safetensors + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + for org_name, param in iterator: + # mapping weight names from transformers to vllm while preserving + # original names. + mapped_name = self.weight_mapper(org_name) + yield org_name, mapped_name, param + + def _get_quantized_weights_iterator( + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + + if bitsandbytes.__version__ < "0.45.3": + raise ImportError("bitsandbytes version is wrong. 
Please " + "install bitsandbytes>=0.45.3.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.45.3 via " + "`pip install bitsandbytes>=0.45.3` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict: Dict[str, Any] = {} + + if pre_quant: + if load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + def _is_8bit_weight_name(self, weight_name: str): + quantized_suffix = {".scb", ".weight_format"} + return any(weight_name.lower().endswith(suffix) + for suffix in quantized_suffix) + + def _is_4bit_weight_name(self, weight_name: str): + quantized_suffix = { + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "bitsandbytes", + } + suffix = weight_name.split(".")[-1] + return any(q_suffix in suffix for q_suffix in quantized_suffix) + + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if not mapped_weight_name.lower().endswith(".scb"): + continue + + weight_key = mapped_weight_name.lower().replace(".scb", ".weight") + quant_state_dict[weight_key] = weight_tensor + + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if self._is_8bit_weight_name(mapped_weight_name): + continue + + if mapped_weight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield org_weight_name, weight_tensor + else: + yield org_weight_name, weight_tensor + + def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in weight_iterator: + if not self._is_4bit_weight_name(mapped_weight_name): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in mapped_weight_name: + temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[mapped_weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, + device=current_platform.device_type) + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if self._is_4bit_weight_name(mapped_weight_name): + continue + + if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" + in temp_state_dict) or ( + f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" + in temp_state_dict): + quant_state = _parse_quant_state(mapped_weight_name, + temp_state_dict) + quant_state_dict[mapped_weight_name] = quant_state + yield org_weight_name, weight_tensor + else: + yield org_weight_name, weight_tensor + + def _unquantized_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import quantize_4bit + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if any(target_module in mapped_weight_name + for target_module in self.target_modules + ) and mapped_weight_name.endswith(".weight"): + # Without sharding + if any( + mapped_weight_name.startswith(module) + for module in self.unsharded_weights_modules): + weight_sub_tensor = weight_tensor + # Shard by column + elif any( + mapped_weight_name.startswith(module) + for module in self.column_sharded_weights_modules): + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + # Weights have fused on disk. In this case, we assume that the + # weight and module use same name. + elif any( + mapped_weight_name.startswith(module) + for module in self.maybe_fused_weights_modules): + # special case for fused weights + # get the size of each shard weight tensor + total_shard_sizes = next( + (sizes for module, sizes in + self.maybe_fused_weights_modules.items() + if mapped_weight_name.startswith(module))) + total_size = weight_tensor.size(0) + assert total_size == sum(total_shard_sizes) + # get the start/end index of each shard weight tensor + total_start_index = list( + itertools.accumulate([0] + total_shard_sizes))[:-1] + shard_weights_index = [( + idx + size // tp_size * tp_rank, + idx + size // tp_size * (tp_rank + 1), + ) for idx, size in zip(total_start_index, + total_shard_sizes)] + # slice and reorder the weight tensor + weight_tensor = [ + weight_tensor[start_index:end_index, ...] + for start_index, end_index in shard_weights_index + ] + weight_sub_tensor = torch.cat(weight_tensor, dim=0) + # Shard by row + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] 
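# --- Annotation: illustrative sketch, not part of the diff. The branches above slice the
# --- full on-disk weight into this rank's shard before it is handed to quantize_4bit.
# --- The tensor shape and TP settings below are assumed purely for the example.
import torch

tp_size, tp_rank = 2, 1                      # hypothetical tensor-parallel setup
weight = torch.randn(4096, 1024)             # full weight as read from the checkpoint

# Row sharding (the final `else` branch): split dim 0 evenly across ranks.
rows = weight.size(0)
row_shard = weight[rows // tp_size * tp_rank:rows // tp_size * (tp_rank + 1), ...]
assert row_shard.shape == (2048, 1024)

# Column sharding (modules listed in column_sharded_weights_modules): split dim -1 instead.
cols = weight.size(-1)
col_shard = weight[..., cols // tp_size * tp_rank:cols // tp_size * (tp_rank + 1)]
assert col_shard.shape == (4096, 512)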
+ + # bitsandbytes requires data in GPU + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4", + ) + + quant_state_dict[mapped_weight_name] = quant_state + else: + processed_weight = weight_tensor + yield org_weight_name, processed_weight + + def _get_bnb_target_modules(self, model: nn.Module) -> None: + + for name, module in model.named_modules(): + if isinstance(module, (LinearBase, )): + if modules_info := self.modules_mapping.get_sub_modules(name): + # Map vllm's names to transformers's names. + rep_name, sub_modules = modules_info + for sub_name in sub_modules: + self.target_modules.append( + name.replace(rep_name, sub_name)) + # Add original module name even if the module has stacked map, + # in case model has a mixture of disk-merged and disk-splitted + # weights with same last name. + self.target_modules.append(name) + + assert (self.target_modules + ), "vllm currently does not support BNB quantization for" + f" {type(model).__name__}" + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, "load_weights"): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}.") + + if not hasattr(model, "packed_modules_mapping"): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet. No 'packed_modules_mapping' found.") + + self.modules_mapping = ParamMapping( + copy.deepcopy(model.packed_modules_mapping)) + + # For some models like Molmo, we need to use hf_to_vllm_mapper + # to ensure correct loading of weights. + if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): + self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + + # Modules whose weights might have fused on disk + # we need their output_sizes to make shard in flight correctly with TP + self.maybe_fused_weights_modules: Dict[str, List[int]] = {} + self._get_bnb_target_modules(model) + for name, module in model.named_modules(): + # Some modules like `ReplicatedLinear` should not have their weights + # sharded. The reason for implementing it this way is to avoid new + # static variable in the model implementation. + if isinstance(module, (ReplicatedLinear, )): + self.unsharded_weights_modules.append(name) + # `QKVParallelLinear` and `MergedColumnParallelLinear` might have + # fused weights on disk. We need to use the output sizes of these + # modules to shard the weights correctly. + elif isinstance(module, + (QKVParallelLinear, MergedColumnParallelLinear)): + self.maybe_fused_weights_modules[name] = module.output_sizes + # In TP, these weights are partitioned along the column + # dimension (dim=-1) + elif isinstance(module, (RowParallelLinear, )): + self.column_sharded_weights_modules.append(name) + + self.model_type = type(model).__name__ + + logger.info("Loading weights with BitsAndBytes quantization. 
" + "May take a while ...") + + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get("quant_method") + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization") + + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with tensor parallelism is not " + "supported. Please try with pipeline parallelism.") + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get("load_in_8bit", False) + + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator(model_config.model, + model_config.revision, + pre_quant, load_8bit)) + + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(qweight_iterator) + # Some models may have weights loading tracker unimplemented. + if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + # TODO: Change this lazy import to normal import + # after the checks are updated to run on a new version + from vllm.model_executor.models.utils import is_pp_missing_parameter + + for quant_param_name in quant_state_dict: + if is_pp_missing_parameter(quant_param_name, model): + continue + + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, + index, + ) in self.modules_mapping.inverse_packed_mapping.items(): + # Some models, such as MiniCPM V2.5/2.6, contain both + # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' + # from being incorrectly identified as being present in + # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight + shard_pos = quant_param_name.find(shard_name) + can_correct_rename = (shard_pos + > 0) and (quant_param_name[shard_pos - 1] + == ".") + # If the quant_param_name is packed, it won't occur in the + # param_dict before renaming. + new_quant_param_name = quant_param_name.replace( + shard_name, weight_name) + need_rename = (quant_param_name not in param_dict) \ + and (new_quant_param_name in param_dict) + if can_correct_rename and need_rename: + shard_index = index + quant_param_name = new_quant_param_name + break + + # Models like Clip/Siglip may skip some layers in initialization, + # causing unused quant_param_name in state_dict. 
+ if quant_param_name not in param_dict: + continue + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = (math.prod(quant_state.shape) // + pack_ratio) + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + # Make torch infer_schema happy + offsets = torch.tensor(offsets).cpu() + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)}) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + + self._load_weights(model_config, model) + + return model.eval() + + +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + guff_files = glob.glob(os.path.join(model_name_or_path, "*.gguf")) + if not guff_files: + raise ValueError( + f"Could not find checkpoint files *.gguf', only gguf files are currently supported!") + return guff_files + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. 
+ """ + config = model_config.hf_config + model_type = config.model_type + gguf_to_hf_name_map = {} + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + if model_type in ("deepseek_v3", "deepseek_v2"): + model_type = "deepseek2" + # GGUF layer map assumes that we will have a merged expert weights + # so we need to map them manually + for idx in range(config.num_hidden_layers): + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config( + config, trust_remote_code=model_config.trust_remote_code) + state_dict = dummy_model.state_dict() + + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, + gguf_to_hf_name_map) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + gguf_extra_tensor_names = [] + local_model_path = [os.path.join(model_config.model, x) for x in local_model_path] + for file in local_model_path: + gguf_extra_tensor_names.extend(get_gguf_extra_tensor_names(file, gguf_weights_map)) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in gguf_extra_tensor_names: + model_config.hf_config.update({"tie_word_embeddings": True}) + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model(vllm_config=vllm_config) + + all_tensors = [] + all_reader = [] + for gguf_file in local_model_path: + reader = gguf.GGUFReader(gguf_file) + all_tensors.extend(reader.tensors) + all_reader.append(reader) + + def all_gguf_weights(): + for tensor in all_tensors: + if tensor.name in gguf_weights_map: + weight_type = tensor.tensor_type + name = gguf_weights_map[tensor.name] + + if weight_type.name != "F32": + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in all_tensors: + if tensor.name in gguf_weights_map: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_weights_map[tensor.name] + + if weight_type.name != "F32": + name = name.replace("weight", "qweight") + param = 
torch.tensor(weight) + + yield name, param + + model.load_weights(all_gguf_weights()) + + setattr(model, "reader_for_ref", all_reader) + from tqdm import tqdm + for _, module in tqdm(model.named_modules()): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None and hasattr(quant_method, "process_weights_after_loading"): + if module.__class__.__name__ == "FusedMoE": + quant_method.process_weights_after_loading(module, all_tensors, gguf_weights_map) + else: + quant_method.process_weights_after_loading(module) + + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # When attention modules need to process weights after + # currently only used by MLA + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) + + torch.cuda.empty_cache() + + return model + + +class RunaiModelStreamerLoader(BaseModelLoader): + """ + Model loader that can load safetensors + files from local FS or S3 bucket. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + extra_config = load_config.model_loader_extra_config + + if ("concurrency" in extra_config + and isinstance(extra_config.get("concurrency"), int)): + os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( + extra_config.get("concurrency")) + + if ("memory_limit" in extra_config + and isinstance(extra_config.get("memory_limit"), int)): + os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( + extra_config.get("memory_limit")) + + runai_streamer_s3_endpoint = os.getenv( + 'RUNAI_STREAMER_S3_ENDPOINT') + aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') + if (runai_streamer_s3_endpoint is None + and aws_endpoint_url is not None): + os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> List[str]: + """Prepare weights for the model. 
+ + If the model is not local, it will be downloaded.""" + + is_s3_path = is_s3(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) + safetensors_pattern = "*.safetensors" + index_file = SAFE_WEIGHTS_INDEX_NAME + + hf_folder = (model_name_or_path if + (is_local or is_s3_path) else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + )) + if is_s3_path: + hf_weights_files = s3_glob(path=hf_folder, + allow_pattern=[safetensors_pattern]) + else: + hf_weights_files = glob.glob( + os.path.join(hf_folder, safetensors_pattern)) + + if not is_local and not is_s3_path: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, self.load_config.download_dir, + revision) + + if not hf_weights_files: + raise RuntimeError( + f"Cannot find any safetensors model weights with " + f"`{model_name_or_path}`") + + return hf_weights_files + + def _get_weights_iterator( + self, model_or_path: str, + revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_weights_files = self._prepare_weights(model_or_path, revision) + return runai_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + + def download_model(self, model_config: ModelConfig) -> None: + """Download model if necessary""" + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + """Perform streaming of the model to destination""" + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model(vllm_config=vllm_config) + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + model.load_weights( + self._get_weights_iterator(model_weights, + model_config.revision)) + + _process_weights_after_loading(model, model_config, target_device) + return model.eval() + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + if load_config.load_format == LoadFormat.GGUF: + return GGUFModelLoader(load_config) + + if load_config.load_format == LoadFormat.RUNAI_STREAMER: + return RunaiModelStreamerLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py new file mode 100644 index 0000000..d900fb3 --- /dev/null +++ b/vllm/model_executor/model_loader/neuron.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for selecting and loading neuron models.""" +import copy +import importlib +import os +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from transformers import 
PretrainedConfig + +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput) + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "f32", + "half": "f16", + "float16": "f16", + "bfloat16": "bf16", + "float": "f32", + "float32": "f32", + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", +} + +# Models supported by Neuron. +_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { + "LlamaForCausalLM": ("transformers_neuronx.llama.model", + "LlamaForSampling", "LlamaForCausalLM"), + "MistralForCausalLM": ("transformers_neuronx.mistral.model", + "MistralForSampling", "MistralForCausalLM") +} + + +class NeuronCausalLM(nn.Module): + + def __init__(self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + + self.on_device_sampling_disabled = on_device_sampling_disabled + if self.on_device_sampling_disabled: + # Use default sampler + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + ) -> torch.Tensor: + logits = self.model(input_ids, + cache_ids=positions, + start_ids=input_block_ids) + return logits + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + if self.on_device_sampling_disabled: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + # On-device sampling outputs the token ids directly. + sampled_token_ids = logits.flatten() + next_tokens = [] + sample_idx = 0 + for seq_group in sampling_metadata.seq_groups: + samples = [] + for seq_id in seq_group.seq_ids: + token_id = sampled_token_ids[sample_idx].item() + samples.append( + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)})) + sample_idx += 1 + next_tokens.append( + CompletionSequenceGroupOutput(samples=samples, + prompt_logprobs=None)) + + return SamplerOutput(outputs=next_tokens) + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + + self.model = neuronx_model_cls.from_pretrained(model_name_or_path, + **kwargs) + self.model.to_neuron() + + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _NEURON_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on Neuron " + f"for now. 
Supported architectures: " + f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + + +def _get_buckets(env: str, default_value: List[int]) -> List[int]: + env_value = os.getenv(env) + if env_value is None: + return default_value + buckets_remove_empty = filter( + lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) + buckets_int = map(int, buckets_remove_empty) + buckets_list = list(buckets_int) + return buckets_list + + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + quant_config = dict( + dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + quantize_method="vector_dynamic") + neuron_quantization_config_builder = lambda quant: get_quantization_config( + quant).from_config(quant_config).get_quant_method(None, "") + # TODO: Add Paged attention config to the default neuron arguments. + default_neuron_args = dict( + collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + quant=neuron_quantization_config_builder(model_config.quantization) + if model_config.quantization else None, + continuous_batching=continuous_batching_config, + weight_tiling=bool(model_config.quantization), + on_device_generation=_get_neuron_on_device_generation_config( + model_config)) + return default_neuron_args + + +def _get_neuron_on_device_generation_config(model_config: ModelConfig): + if not _is_neuron_on_device_sampling_disabled(model_config): + return copy.deepcopy(model_config.neuron_sampling_params) + return None + + +def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool: + return not getattr(model_config, "neuron_sampling_params", None) + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + from transformers_neuronx.config import NeuronConfig + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return NeuronConfig(**default_neuron_config) + + +def get_neuron_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + + # Create a model instance. + model = NeuronCausalLM( + model_config.hf_config, + _is_neuron_on_device_sampling_disabled(model_config)) + + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + # Load the weights from the cached or downloaded files. 
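# --- Annotation: illustrative sketch, not part of the diff. _get_buckets (defined above)
# --- lets the context-length and token-generation buckets be overridden with a
# --- comma-separated environment variable; otherwise max_model_len is used.
# --- The example value below is assumed.
import os

os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512, 2048,"   # stray spaces/empties are tolerated
env_value = os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"]
buckets = [int(x) for x in env_value.split(",") if x.strip()]    # equivalent to _get_buckets
print(buckets)  # -> [128, 512, 2048]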
+ model.load_weights(model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + return model.eval() diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py new file mode 100644 index 0000000..117251c --- /dev/null +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -0,0 +1,468 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import dataclasses +import io +import os +import re +import time +from dataclasses import dataclass +from functools import partial +from typing import BinaryIO, Generator, Optional, Tuple, Type, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.logger import init_logger +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.utils import FlexibleArgumentParser, PlaceholderModule + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) + + _read_stream, _write_stream = (partial( + open_stream, + mode=mode, + ) for mode in ("rb", "wb+")) +except ImportError: + tensorizer = PlaceholderModule("tensorizer") + DecryptionParams = tensorizer.placeholder_attr("DecryptionParams") + EncryptionParams = tensorizer.placeholder_attr("EncryptionParams") + TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer") + TensorSerializer = tensorizer.placeholder_attr("TensorSerializer") + open_stream = tensorizer.placeholder_attr("stream_io.open_stream") + convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes") + get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage") + no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") + + _read_stream = tensorizer.placeholder_attr("_read_stream") + _write_stream = tensorizer.placeholder_attr("_write_stream") + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor', 'TensorizerConfig' +] + +logger = init_logger(__name__) + + +@dataclass +class TensorizerConfig: + tensorizer_uri: str + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[Type[torch.nn.Module]] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Optional[Union[str, torch.dtype]] = None + _is_sharded: bool = False + + def __post_init__(self): + # check if the configuration is for a sharded vLLM model + self._is_sharded = isinstance(self.tensorizer_uri, str) \ + and re.search(r'%0\dd', self.tensorizer_uri) is not None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": 
self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) # type: ignore + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if parallel_config.tensor_parallel_size > 1 \ + and not self._is_sharded: + raise ValueError( + "For a sharded model, tensorizer_uri should include a" + " string format template like '%04d' to be formatted" + " with the rank of the shard") + + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + logger.warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): + if tensorizer_args is None: + tensorizer_args = self._construct_tensorizer_args() + + return open_stream(self.tensorizer_uri, + **tensorizer_args.stream_params) + + +def load_with_tensorizer(tensorizer_config: TensorizerConfig, + **extra_kwargs) -> nn.Module: + tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + return tensorizer.deserialize() + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, + bytes, os.PathLike, int] + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/other/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. 
+ """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID + self.s3_secret_access_key = (self.s3_secret_access_key + or envs.S3_SECRET_ACCESS_KEY) + self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Tensorizer CLI arguments""" + + # Tensorizer options arg group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + 'load_format=tensorizer is specified when ' + 'initializing an LLMEngine, either via the CLI ' + 'when running the vLLM OpenAI inference server ' + 'with a JSON string passed to ' + '--model-loader-extra-config or as arguments given ' + 'to TensorizerConfig when passed to ' + 'model_loader_extra_config in the constructor ' + 'for LLMEngine.')) + + group.add_argument( + "--tensorizer-uri", + help="Path to serialized model tensors. Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=None, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file. Default is `None`, which will dynamically " + "set the number of readers based on the available resources " + "and model size. This greatly increases performance.") + group.add_argument( + "--s3-access-key-id", + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + attrs = [attr.name for attr in dataclasses.fields(cls)] + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +class TensorizerAgent: + """ + A class for performing tensorizer deserializations specifically for + vLLM models using plaid_mode. Uses TensorizerArgs to configure the + behavior of the TensorDeserializer when loading tensors from a serialized + model. 
For deserializations of HuggingFace models, TensorDeserializer is + instead used as an iterator directly in the func hf_model_weights_iterator + in vllm/model_executor/model_loader/weight_utils.py + """ + + def __init__(self, tensorizer_config: TensorizerConfig, vllm_config): + self.tensorizer_config = tensorizer_config + self.tensorizer_args = ( + self.tensorizer_config._construct_tensorizer_args()) + self.vllm_config = vllm_config + self.model = self._init_model() + + def _init_model(self): + assert self.tensorizer_config.hf_config is not None + model_args = self.tensorizer_config.hf_config + model_args.torch_dtype = self.tensorizer_config.dtype + assert self.tensorizer_config.model_class is not None + # TODO: Do we need to consider old-style model class? + with no_init_or_tensor(), set_current_vllm_config(self.vllm_config, + check_compile=True): + return self.tensorizer_config.model_class( + vllm_config=self.vllm_config, ) + + def _resize_lora_embeddings(self): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in self.model.modules(): + if (isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] + < child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + def _check_tensors_on_meta_device(self): + for tensor in self.model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + def deserialize(self): + """ + Deserialize the model using the TensorDeserializer. This method is + specifically for vLLM models using tensorizer's plaid_mode. + + The deserializer makes use of tensorizer_args.stream_params + to configure the behavior of the stream when loading tensors from a + serialized model. The deserializer_params are used to configure the + behavior of the TensorDeserializer when loading tensors themselves. + Documentation on these params can be found in TensorizerArgs + + Returns: + nn.Module: The deserialized model. 
+ """ + before_mem = get_mem_usage() + start = time.perf_counter() + with _read_stream( + self.tensorizer_config.tensorizer_uri, + **self.tensorizer_args.stream_params + ) as stream, TensorDeserializer( + stream, + dtype=self.tensorizer_config.dtype, + device=f'cuda:{torch.cuda.current_device()}', + **self.tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(self.model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, + end - start, per_second) + logger.info("Memory usage before: %s", before_mem) + logger.info("Memory usage after: %s", after_mem) + + self._check_tensors_on_meta_device() + self._resize_lora_embeddings() + del self.model.vllm_tensorized_marker + return self.model.eval() + + +def tensorizer_weights_iterator( + tensorizer_args: "TensorizerArgs" +) -> Generator[Tuple[str, torch.Tensor], None, None]: + logger.warning("Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the " + "examples/other/tensorize_vllm_model.py example script " + "for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + yield from state.items() + del state + + +def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: + """ + Infer if the model is a vLLM model by checking the weights for + a vLLM tensorized marker. + + Args: + tensorizer_config: The TensorizerConfig object containing the + tensorizer_uri to the serialized model. + + Returns: + bool: True if the model is a vLLM model, False otherwise. 
+ """ + tensorizer_args = tensorizer_config._construct_tensorizer_args() + deserializer = TensorDeserializer(open_stream( + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), + **tensorizer_args.deserializer_params, + lazy_load=True) + if tensorizer_config.vllm_tensorized: + logger.warning( + "Please note that newly serialized vLLM models are automatically " + "inferred as vLLM models, so setting vllm_tensorized=True is " + "only necessary for models serialized prior to this change.") + return True + return ".vllm_tensorized_marker" in deserializer + + +def serialize_vllm_model( + model: nn.Module, + tensorizer_config: TensorizerConfig, +) -> nn.Module: + model.register_parameter( + "vllm_tensorized_marker", + nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + + encryption_params = None + if (keyfile := tensorizer_config.encryption_keyfile) is not None: + with open(keyfile, "rb") as f: + key = f.read() + encryption_params = EncryptionParams(key=key) + + output_file = tensorizer_args.tensorizer_uri + if tensorizer_config._is_sharded: + from vllm.distributed import get_tensor_model_parallel_rank + output_file = output_file % get_tensor_model_parallel_rank() + + with _write_stream(output_file, **tensorizer_args.stream_params) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + logger.info("Successfully serialized model to %s", str(output_file)) + return model + + +def tensorize_vllm_model(engine_args: EngineArgs, + tensorizer_config: TensorizerConfig, + generate_keyfile: bool = True): + """Utility to load a model and then serialize it with Tensorizer + + Intended to be used separately from running a vLLM server since it + creates its own Engine instance. 
+ """ + engine_config = engine_args.create_engine_config() + tensorizer_config.verify_with_model_config(engine_config.model_config) + tensorizer_config.verify_with_parallel_config( + engine_config.parallel_config) + + # generate the encryption key before creating the engine to support sharding + if generate_keyfile and (keyfile := + tensorizer_config.encryption_keyfile) is not None: + encryption_params = EncryptionParams.random() + with _write_stream( + keyfile, + s3_access_key_id=tensorizer_config.s3_access_key_id, + s3_secret_access_key=tensorizer_config.s3_secret_access_key, + s3_endpoint=tensorizer_config.s3_endpoint, + ) as stream: + stream.write(encryption_params.key) + + engine = LLMEngine.from_engine_args(engine_args) + engine.model_executor.collective_rpc( + "save_tensorized_model", + kwargs=dict(tensorizer_config=tensorizer_config), + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py new file mode 100644 index 0000000..15f37aa --- /dev/null +++ b/vllm/model_executor/model_loader/utils.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for selecting and loading models.""" +import contextlib +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type + +import torch +import transformers +from torch import nn +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from vllm.config import ModelConfig, ModelImpl +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.adapters import (as_classification_model, + as_embedding_model, + as_reward_model) + +logger = init_logger(__name__) + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def is_transformers_impl_compatible( + arch: str, + module: Optional["transformers.PreTrainedModel"] = None) -> bool: + mod = module or getattr(transformers, arch, None) + if mod is None: + return False + return mod.is_backend_compatible() + + +def resolve_transformers_arch(model_config: ModelConfig, + architectures: list[str]): + for i, arch in enumerate(architectures): + if arch == "TransformersForCausalLM": + continue + auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", + None) or dict() + # Make sure that config class is always initialized before model class, + # otherwise the model class won't be able to access the config class, + # the expected auto_map should have correct order like: + # "auto_map": { + # "AutoConfig": "--", + # "AutoModel": "--", + # "AutoModelFor": "--", + # }, + auto_modules = { + name: get_class_from_dynamic_module(module, model_config.model) + for name, module in sorted(auto_map.items(), key=lambda x: x[0]) + } + custom_model_module = auto_modules.get("AutoModel") + # TODO(Isotr0py): Further clean up these raises. + # perhaps handled them in _ModelRegistry._raise_for_unsupported? 
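# --- Annotation: illustrative sketch, not part of the diff. The sorted() used when building
# --- auto_modules above resolves remote-code classes in key order, so "AutoConfig" is
# --- materialized before any "AutoModel*" entry, as the preceding comment requires.
# --- The auto_map entries below are hypothetical.
auto_map = {
    "AutoModelForCausalLM": "modeling_foo.FooForCausalLM",
    "AutoModel": "modeling_foo.FooModel",
    "AutoConfig": "configuration_foo.FooConfig",
}
print([name for name, _ in sorted(auto_map.items(), key=lambda x: x[0])])
# -> ['AutoConfig', 'AutoModel', 'AutoModelForCausalLM']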
+ if model_config.model_impl == ModelImpl.TRANSFORMERS: + if not is_transformers_impl_compatible(arch, custom_model_module): + raise ValueError( + f"The Transformers implementation of {arch} is not " + "compatible with vLLM.") + architectures[i] = "TransformersForCausalLM" + if model_config.model_impl == ModelImpl.AUTO: + if not is_transformers_impl_compatible(arch, custom_model_module): + raise ValueError( + f"{arch} has no vLLM implementation and the Transformers " + "implementation is not compatible with vLLM. Try setting " + "VLLM_USE_V1=0.") + logger.warning( + "%s has no vLLM implementation, falling back to Transformers " + "implementation. Some features may not be supported and " + "performance may not be optimal.", arch) + architectures[i] = "TransformersForCausalLM" + return architectures + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + mixtral_supported = [ + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin" + ] + + if (model_config.quantization is not None + and model_config.quantization not in mixtral_supported + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] + + vllm_supported_archs = ModelRegistry.get_supported_archs() + is_vllm_supported = any(arch in vllm_supported_archs + for arch in architectures) + if (not is_vllm_supported + or model_config.model_impl == ModelImpl.TRANSFORMERS): + architectures = resolve_transformers_arch(model_config, architectures) + + model_cls, arch = ModelRegistry.resolve_model_cls(architectures) + if model_config.task == "embed": + model_cls = as_embedding_model(model_cls) + elif model_config.task == "classify": + model_cls = as_classification_model(model_cls) + elif model_config.task == "reward": + model_cls = as_reward_model(model_cls) + + return model_cls, arch + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] + + +@dataclass +class ParamMapping: + """ + A class to handle parameter mapping for model weight loading. + It creates a bidirectional mapping between packed parameters and their + constituent parts. + """ + packed_mapping: Dict[str, List[str]] + inverse_packed_mapping: Dict[str, Tuple[str, + int]] = field(default_factory=dict) + + def __post_init__(self): + for packed_name, sub_params in self.packed_mapping.items(): + # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]}) + if len(sub_params) == 1 and sub_params[0] == packed_name: + continue + for index, param_name in enumerate(sub_params): + self.inverse_packed_mapping[param_name] = ( + packed_name, + index, + ) + + def get_sub_modules(self, + module_name: str) -> Optional[Tuple[str, List[str]]]: + for key, value in self.packed_mapping.items(): + if module_name.endswith(key): + return key, value + return None + + +def configure_quant_config(quant_config: QuantizationConfig, + model_class: Type[nn.Module]): + """ + Pass packed_modules_mapping by reference to quant_config so that + quant_config can properly match fused modules + + Note that model attributes are passed by reference to quant_config, + enabling them to be updated by model_class.__new__ (ex. 
chatglm, qwen) + """ + packed_mapping = getattr(model_class, "packed_modules_mapping", None) + if packed_mapping is not None: + # pass packed_modules_mapping by reference to quant_config + quant_config.packed_modules_mapping = packed_mapping + else: + logger.warning( + "The model class %s has not defined `packed_modules_mapping`, " + "this may lead to incorrect mapping of quantized or ignored " + "modules", model_class.__name__) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py new file mode 100644 index 0000000..a747594 --- /dev/null +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -0,0 +1,736 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for downloading and initializing model weights.""" +import fnmatch +import glob +import hashlib +import json +import os +import tempfile +import time +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union + +import filelock +import gguf +import huggingface_hub.constants +import numpy as np +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm + +from vllm.config import LoadConfig, ModelConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (QuantizationConfig, + get_quantization_config) +from vllm.platforms import current_platform +from vllm.utils import PlaceholderModule + +try: + from runai_model_streamer import SafetensorsStreamer +except (ImportError, OSError): + # see https://github.com/run-ai/runai-model-streamer/issues/26 + # OSError will be raised on arm64 platform + runai_model_streamer = PlaceholderModule( + "runai_model_streamer") # type: ignore[assignment] + SafetensorsStreamer = runai_model_streamer.placeholder_attr( + "SafetensorsStreamer") + +try: + from fastsafetensors import SafeTensorsFileLoader, SingleGroup +except ImportError: + fastsafetensors = PlaceholderModule("fastsafetensors") + SafeTensorsFileLoader = fastsafetensors.placeholder_attr( + "SafeTensorsFileLoader") + SingleGroup = fastsafetensors.placeholder_attr("SingleGroup") + +logger = init_logger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. 
+# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer(): + """automatically activates hf_transfer + """ + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: Union[str, Path], + cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + model_name_or_path = str(model_name_or_path) + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), + mode=0o666) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. 
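# --- Annotation: illustrative sketch, not part of the diff. The quantization_config read
# --- just below is the dict exported by HF for pre-quantized checkpoints; the keys shown
# --- here are the ones consumed by the BitsAndBytes loading path earlier in this diff,
# --- and the values are assumed for the example.
hf_quant_config = {
    "quant_method": "bitsandbytes",
    "load_in_4bit": True,
    "load_in_8bit": False,
    "bnb_4bit_quant_type": "nf4",
}
pre_quant = hf_quant_config.get("quant_method") == "bitsandbytes"
load_8bit = hf_quant_config.get("load_in_8bit", False)
print(pre_quant, load_8bit)  # -> True False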
+ hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", + None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") + + return quant_cls.from_config(config) + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, List[str]]] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + ignore_patterns (Optional[Union[str, List[str]]]): The patterns to + filter out the weight files. Files matched by any of the patterns + will be ignored. + + Returns: + str: The path to the downloaded model weights. 
+ """ + local_only = huggingface_hub.constants.HF_HUB_OFFLINE + if not local_only: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + start_time = time.perf_counter() + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=local_only, + ) + time_taken = time.perf_counter() - start_time + if time_taken > 0.5: + logger.info("Time spent downloading weights for %s: %.6f seconds", + model_name_or_path, time_taken) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + index_file: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=index_file, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have index_file. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", index_file) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", index_file) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the index_file to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files(hf_weights_files: List[str], + hf_folder: str, + index_file: str) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, index_file) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add( + os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. 
+ hf_weights_files = [ + f for f in hf_weights_files if f in weight_files_in_index + ] + return hf_weights_files + + +def filter_files_not_needed_for_inference( + hf_weights_files: List[str]) -> List[str]: + """ + Exclude files that are not needed for inference. + + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def enable_tqdm(use_tqdm_on_load: bool): + return use_tqdm_on_load and (not torch.distributed.is_initialized() + or torch.distributed.get_rank() == 0) + + +def np_cache_weights_iterator( + model_name_or_path: str, + cache_dir: Optional[str], + hf_folder: str, + hf_weights_files: List[str], + use_tqdm_on_load: bool, +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. 
+ with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names: List[str] = [] + for bin_file in tqdm( + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, + map_location="cpu", + weights_only=True) + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file) as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: List[str], + use_tqdm_on_load: bool, +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def runai_safetensors_weights_iterator( + hf_weights_files: List[str], + use_tqdm_on_load: bool, +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + with SafetensorsStreamer() as streamer: + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors using Runai Model Streamer", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + streamer.stream_file(st_file) + yield from streamer.get_tensors() + + +def fastsafetensors_weights_iterator( + hf_weights_files: List[str], + use_tqdm_on_load: bool, +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files + using fastsafetensor library.""" + if torch.distributed.is_initialized(): + pg = torch.distributed.group.WORLD + else: + pg = SingleGroup() + + device = torch.device(f'cuda:{pg.rank()}') + weight_files_sub_lists = [ + hf_weights_files[i:i + pg.size()] + for i in range(0, len(hf_weights_files), pg.size()) + ] + + for f_list in tqdm( + weight_files_sub_lists, + desc="Loading safetensors using Fastsafetensor loader", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + loader = SafeTensorsFileLoader(pg, device) + rank_file_map = {i: [f] for i, f in enumerate(f_list)} + loader.add_filenames(rank_file_map) + try: + fb = loader.copy_files_to_device() + try: + keys = list(fb.key_to_rank_lidx.keys()) + for k in keys: + t = fb.get_tensor(k) + yield k, t + finally: + fb.close() + finally: + loader.close() + + +def pt_weights_iterator( + hf_weights_files: List[str], + use_tqdm_on_load: bool, +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location="cpu", weights_only=True) + yield from state.items() + del state + + +def get_gguf_extra_tensor_names( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]: + reader = gguf.GGUFReader(gguf_file) + 
expected_gguf_keys = set(gguf_to_hf_name_map.keys()) + exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) + extra_keys = expected_gguf_keys - exact_gguf_keys + return [gguf_to_hf_name_map[key] for key in extra_keys] + + +def gguf_quant_weights_iterator( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """ + Iterate over the quant weights in the model gguf files and convert + them to torch tensors + """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + if weight_type.name != "F32": + name = name.replace("weight", "qweight") + param = torch.tensor(weight) + yield name, param + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. + """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + try: + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})") + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. 
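+        # (The bare `raise` re-raises the original exception unchanged, so
+        # the stack trace still points at the failing copy above.)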
+ raise + + +def row_parallel_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Load weights that are row-parallelized.""" + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + + if shard_dim is not None: + shard_size = param.data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +def sharded_weight_loader(shard_axis: int) -> LoaderFunction: + """Create a weight loader that shards the weights along the given axis""" + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tensor_model_parallel_rank() + + shard_size = param.data.shape[shard_axis] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + return loader + + +def composed_weight_loader( + loader: LoaderFunction, fn: Callable[[torch.Tensor], + torch.Tensor]) -> LoaderFunction: + """Create a weight loader that post-processes the weights after loading""" + + def composed_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + loader(param, loaded_weight) + param.data.copy_(fn(param)) + return + + return composed_loader + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + + We use per-parameter random seed, so that dummy weights are consistent, + even if the model is partitioned across multiple devices. When the seed + is fixed, the random values generated by this function only depends on + the parameter's number of elements and its data type. + """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + if current_platform.is_tpu(): + # XLA device does not support torch.Generator() + param.uniform_(low, high) + continue + + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high, + generator=generator).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high, generator=generator) + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. 
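+
+    Example:
+        Illustrative only; the shown result assumes the remapped name is
+        present in ``params_dict``:
+
+        >>> maybe_remap_kv_scale_name(
+        ...     "model.layers.0.self_attn.kv_scale", params_dict)
+        'model.layers.0.self_attn.attn.k_scale'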
+ """ + if name.endswith(".kv_scale"): + logger.warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale") + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + logger.warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded.") + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + modelopt_scale_names = [ + ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale" + ] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + if any(mo_scale_name in name + for mo_scale_name in modelopt_scale_names): + remapped_name = name.replace( + f".self_attn.{scale_name[1]}_proj{scale_name}", + f".self_attn.attn{scale_name}") + else: + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + logger.warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). {scale_name} is " + "not loaded.") + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py new file mode 100644 index 0000000..3580c4f --- /dev/null +++ b/vllm/model_executor/models/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, + SupportsPP, SupportsV0Only, has_inner_state, + supports_lora, supports_multimodal, supports_pp, + supports_v0_only) +from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, + is_pooling_model, is_text_generation_model) +from .registry import ModelRegistry + +__all__ = [ + "ModelRegistry", + "VllmModelForPooling", + "is_pooling_model", + "VllmModelForTextGeneration", + "is_text_generation_model", + "HasInnerState", + "has_inner_state", + "SupportsLoRA", + "supports_lora", + "SupportsMultiModal", + "supports_multimodal", + "SupportsPP", + "supports_pp", + "SupportsV0Only", + "supports_v0_only", +] diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py new file mode 100644 index 0000000..6ab03c4 --- /dev/null +++ b/vllm/model_executor/models/adapters.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +import torch +import torch.nn as nn + +from .interfaces_base import VllmModelForPooling, is_pooling_model + +if TYPE_CHECKING: + from vllm.model_executor.layers.pooler import PoolingType + +_T = TypeVar("_T", bound=type[nn.Module]) + +_GENERATE_SUFFIXES = [ + "ForCausalLM", + "ForConditionalGeneration", + "ChatModel", + "LMHeadModel", +] + + +def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: + model_name = orig_model_name + + for generate_suffix in _GENERATE_SUFFIXES: + model_name = model_name.removesuffix(generate_suffix) + + return model_name + pooling_suffix + + +def _create_pooling_model_cls( + orig_cls: _T, + *, + default_pooling_type: "PoolingType", + 
default_normalize: bool, + default_softmax: bool, +) -> _T: + # Lazy import + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import Pooler, PoolerOutput + from vllm.model_executor.pooling_metadata import PoolingMetadata + + from .utils import AutoWeightsLoader, WeightsMapper + + class ModelForPooling(orig_cls, VllmModelForPooling): + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + # These are not used in pooling models + for attr in ("lm_head", "logits_processor"): + if hasattr(self, attr): + delattr(self, attr) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "_pooler", None): + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + ) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # TODO: Support uninitialized params tracking + + # We have deleted this attribute, so don't load it + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + + # If `*ForCausalLM` defines `load_weights` on the inner model + # and there are no other inner modules with parameters, + # we support loading from both `*Model` and `*ForCausalLM` + if hasattr(self, "model") and hasattr(self.model, "load_weights"): + # Whether only `self.model` contains parameters + model_is_only_param = all( + name == "model" or next(child.parameters(), None) is None + for name, child in self.named_children()) + + if model_is_only_param: + mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = mapper.apply(weights) + + loaded_params = self.model.load_weights(weights) + loaded_params = {f"model.{name}" for name in loaded_params} + return loaded_params + + # For most other models + if hasattr(orig_cls, "load_weights"): + return orig_cls.load_weights(self, weights) # type: ignore + # Fallback + else: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + return ModelForPooling # type: ignore + + +def as_embedding_model(cls: _T) -> _T: + """ + Subclass an existing vLLM model to support embeddings. + + By default, the embeddings of the whole prompt are extracted from the + normalized hidden state corresponding to the last token. + + Note: + We assume that no extra layers are added to the original model; + please implement your own model if this is not the case. + """ + # Avoid modifying existing embedding models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.model_executor.layers.pooler import PoolingType + + ModelForEmbedding = _create_pooling_model_cls( + cls, + default_pooling_type=PoolingType.LAST, + default_normalize=True, + default_softmax=False, + ) + ModelForEmbedding.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForEmbedding") + + return ModelForEmbedding # type: ignore + + +def as_classification_model(cls: _T) -> _T: + """ + Subclass an existing vLLM model to support classification. + + By default, the class probabilities are extracted from the softmaxed + hidden state corresponding to the last token. 
+ + Note: + We assume that the classification head is a single linear layer + stored as the attribute `score` of the top-level model; + please implement your own model if this is not the case. + """ + # Avoid modifying existing classification models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.config import VllmConfig + from vllm.model_executor.layers.linear import RowParallelLinear + from vllm.model_executor.layers.pooler import PoolingType + from vllm.sequence import IntermediateTensors + + from .utils import maybe_prefix + + ModelForPooling = _create_pooling_model_cls( + cls, + default_pooling_type=PoolingType.LAST, + default_normalize=False, + default_softmax=True, + ) + + class ModelForClassification(ModelForPooling): + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.score = RowParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix( + prefix, "score")) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = super().forward(input_ids, positions, + intermediate_tensors, + inputs_embeds) + logits, _ = self.score(hidden_states) + return logits + + + ModelForClassification.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForClassification") + + return ModelForClassification # type: ignore + + +def as_reward_model(cls: _T) -> _T: + """ + Subclass an existing vLLM model to support reward modeling. + + By default, we return the hidden states of each token directly. + + Note: + We assume that no extra layers are added to the original model; + please implement your own model if this is not the case. 
+ """ + # Avoid modifying existing reward models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.model_executor.layers.pooler import PoolingType + + ModelForReward = _create_pooling_model_cls( + cls, + default_pooling_type=PoolingType.ALL, + default_normalize=False, + default_softmax=False, + ) + + ModelForReward.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForReward") + + return ModelForReward # type: ignore diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 0000000..065715c --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,569 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only Snowflake Arctic model.""" +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.arctic import ArcticConfig + +from .interfaces import SupportsPP, SupportsQuant +from .utils import (extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = ""): + super().__init__() + + layer_id = extract_layer_index(prefix) + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.mlp") + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config, + prefix=f"{prefix}.gate") + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids = fused_topk(hidden_states, + router_logits, + self.top_k, + 
renormalize=do_normalize) + # topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + 
super().__init__() + self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.block_sparse_moe = ArcticMoE( + config, + quant_config=quant_config, + reduce_results=(not self.use_residual), + prefix=f"{prefix}.block_sparse_moe", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + is_residual_mlp=True, + reduce_results=False, + prefix=f"{prefix}.residual_mlp") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +@support_torch_compile +class ArcticModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: ArcticDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + 
if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.model = ArcticModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping: List[Tuple[str, str, int]] = [] + expert_params_mapping: List[Tuple[str, str, int]] = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + 
logger.info( + "It will take ~10 minutes loading from the 16-bit weights. " + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py new file mode 100644 index 0000000..8cd3be9 --- /dev/null +++ b/vllm/model_executor/models/aria.py @@ -0,0 +1,669 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import AriaConfig, AriaTextConfig, BatchFeature +from transformers.models.aria.modeling_aria import AriaCrossAttention +from transformers.models.aria.processing_aria import AriaProcessor + +from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import (SamplerOutput, + SamplingMetadata, get_sampler) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +# yapf: disable +from .idefics2_vision_model import Idefics2VisionConfig +from 
.idefics2_vision_model import ( + Idefics2VisionTransformer as Idefics3VisionTransformer) +# yapf: enable +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant +from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + is_pp_missing_parameter, maybe_prefix, + merge_multimodal_embeddings) + + +class AriaImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + pixel_mask: Optional[torch.Tensor] + """ + Shape: + pixel_values: `(batch_size * num_images, num_channels, height, width)` + pixel_mask: `(batch_size * num_images, height, width)` + """ + + +class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config=quant_config, prefix=prefix) + # Unlike Idefics3VisionTransformer which uses LayerNorm after the + # final layer, Aria omits this normalization, so we replace it with an + # Identity layer + self.post_layernorm = nn.Identity() + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + + # NOTE: post_layernorm is not used in Aria + if "post_layernorm" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class AriaProjectorMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + output_dim: int, + ) -> None: + super().__init__() + + self.linear_in = ColumnParallelLinear(in_features, + hidden_features, + bias=False) + self.linear_out = RowParallelLinear(hidden_features, + output_dim, + bias=False) + self.act = get_act_fn("gelu_new") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_out(hidden_states) + return hidden_states + + +class AriaProjector(nn.Module): + """ + A projection module with one cross attention layer and one FFN layer, which + projects ViT's outputs into MoE's inputs. + + Args: + patch_to_query_dict (dict): Maps patch numbers to their corresponding + query numbers, + e.g., {1225: 128, 4900: 256}. This allows for different query sizes + based on image resolution. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. 
+ + Outputs: + A tensor with the shape of (batch_size, query_number, output_dim) + """ + + def __init__(self, config: AriaConfig) -> None: + super().__init__() + + self.patch_to_query_dict = config.projector_patch_to_query_dict + self.in_features = config.vision_config.hidden_size + self.num_heads = config.vision_config.num_attention_heads + self.kv_dim = config.vision_config.hidden_size + self.hidden_features = config.text_config.hidden_size + self.output_dim = config.text_config.hidden_size + + self.query = nn.Parameter( + torch.empty(config.max_value_projector_patch_to_query_dict, + self.in_features)) + + self.cross_attn = AriaCrossAttention(config) + + self.layer_norm = nn.LayerNorm(self.in_features) + self.feed_forward = AriaProjectorMLP(self.in_features, + self.hidden_features, + self.output_dim) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, num_patches = x.shape[0], x.shape[1] + + if num_patches not in self.patch_to_query_dict: + raise KeyError(f"Number of patches {num_patches} not found in " + "patch_to_query_dict amongst possible values " + f"{self.patch_to_query_dict.keys()}.") + + query_num = self.patch_to_query_dict[num_patches] + + queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1) + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) + + out = self.feed_forward(self.layer_norm(attention_out)) + + return out + + +class AriaFusedMoE(FusedMoE): + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + shard_id: str) -> None: + # Override the weight_loader to handle the expert weights in the Aria + # model, which are already packed with experts, and merge the gate and + # up weights for each expert. + # Note: Loading expert weights with quantization is not supported + tp_rank = get_tensor_model_parallel_rank() + if shard_id == 'w13': + # the shape of loaded_weight is + # (num_experts, hidden_size, 2 * moe_intermediate_size) + if self.tp_size > 1: + up, gate = loaded_weight.chunk(2, dim=-1) + up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] + gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] + up_and_gate = torch.cat([up_current_rank, gate_current_rank], + dim=-1).transpose(1, 2) + param.data.copy_(up_and_gate) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + elif shard_id == 'w2': + # the shape of loaded_weight is + # (num_experts, moe_intermediate_size, hidden_size) + if self.tp_size > 1: + down_current_rank = loaded_weight.chunk(self.tp_size, + dim=1)[tp_rank] + param.data.copy_(down_current_rank.transpose(1, 2)) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + + +class AriaTextMoELayer(nn.Module): + """ + Mixture of Experts (MoE) Layer for the AriaMoE model. + + This layer implements the MoE mechanism, which routes input tokens to + different experts based on a routing algorithm, processes them through the + experts, and then combines the outputs. 
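+
+    Example:
+        Illustrative shapes only; ``layer`` is an initialized
+        ``AriaTextMoELayer`` and ``hidden_states`` has shape
+        ``(num_tokens, hidden_size)``:
+
+        >>> out = layer(hidden_states)  # output keeps the input shape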
+ """ + + def __init__( + self, + config: AriaTextConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.router_weight = nn.Parameter( + torch.empty( + (self.config.moe_num_experts, self.config.hidden_size))) + + self.experts = AriaFusedMoE( + num_experts=config.moe_num_experts, + top_k=config.moe_topk, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.experts", + ) + self.shared_experts = LlamaMLP( + config.hidden_size, + config.intermediate_size * config.moe_num_shared_experts, + "silu", + quant_config=quant_config, + bias=config.mlp_bias, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch_size, + sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + """ + + router_output = torch.nn.functional.linear(hidden_states, + self.router_weight) + + hidden_states_copy = hidden_states.clone() + # NOTE: hidden_states will be modified inplace by `FusedMoE` + sparse_expert_output = self.experts(hidden_states, router_output) + shared_expert_output = self.shared_experts(hidden_states_copy) + + return sparse_expert_output + shared_expert_output + + +class AriaTextDecoderLayer(LlamaDecoderLayer): + """ + Custom Decoder Layer for the AriaMoE model which modifies the standard + `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of + Experts (MoE) Layer. + """ + + def __init__( + self, + config: AriaTextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, cache_config, quant_config, prefix) + self.mlp = AriaTextMoELayer(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + +class AriaTextModel(LlamaModel, SupportsQuant): + """ + Custom LlamaModel for the AriaMoE model which modifies the standard + LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. + """ + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "experts.w13_weight": ["experts.fc1.weight"], + "experts.w2_weight": ["experts.fc2.weight"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=AriaTextDecoderLayer) + + # Adapted from LlamaModel.load_weights with the modification of adding + # the expert weights mapping to `stacked_params_mapping` + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ("experts.w13_weight", "experts.fc1.weight", 'w13'), + ("experts.w2_weight", "experts.fc2.weight", 'w2'), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class AriaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(AriaConfig) + + def get_vision_config(self): + return self.get_hf_config().vision_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(AriaProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return max(hf_config.projector_patch_to_query_dict.values()) + + +class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + vision_config = self.info.get_vision_config() + + max_image_size = vision_config.image_size + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + hf_processor = self.info.get_hf_processor() + image_token: str = hf_processor.tokenizer.image_token # type: ignore + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + pixel_mask=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + num_image_tokens = self.info.get_num_image_tokens() + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=[image_token_id] * 
num_image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, + info=AriaProcessingInfo, + dummy_inputs=AriaDummyInputsBuilder) +class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): + """ + Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language + model to perform tasks that involve both image and text inputs. + """ + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "language_model", + "language_model.lm_head": "lm_head", + }, + orig_to_new_suffix={ + "router.weight": "router_weight", + }, + ) + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.vision_tower = AriaVisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=f"{prefix}.vision_tower", + ) + self.multi_modal_projector = AriaProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AriaTextModel( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model.model"), + ) + self.pad_token_id = (self.config.pad_token_id + if self.config.pad_token_id is not None else -1) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size, + quant_config=quant_config, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.vocab_size, logit_scale) + self.sampler = get_sampler() + + def _validate_image_sizes( + self, images: List[torch.Tensor]) -> List[torch.Tensor]: + if not all(img.shape == images[0].shape for img in images): + raise ValueError("All images must be the same size") + return images + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[AriaImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + pixel_mask = kwargs.pop("pixel_mask", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = self._validate_image_sizes(pixel_values) + pixel_values = flatten_bn(pixel_values, concat=True) + + if pixel_mask is not None: + if not isinstance(pixel_mask, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel mask. 
" + f"Got type: {type(pixel_mask)}") + + pixel_mask = flatten_bn(pixel_mask, concat=True) + + return AriaImagePixelInputs( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + ) + + def _create_patch_attention_mask( + self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ).unfold( + dimension=2, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def _process_image_input( + self, image_input: AriaImagePixelInputs + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input['pixel_values'] + pixel_mask = image_input['pixel_mask'] + + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + + image_outputs = self.vision_tower( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + image_attn_mask = None + if patch_attention_mask is not None: + flattened_mask = patch_attention_mask.flatten(1) + image_attn_mask = torch.logical_not(flattened_mask) + + return self.multi_modal_projector(image_outputs, image_attn_mask) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + multimodal_embeddings = self._process_image_input(image_input) + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py new file mode 100644 index 0000000..b4bf1d8 --- /dev/null +++ b/vllm/model_executor/models/aya_vision.py @@ -0,0 +1,527 @@ +# SPDX-License-Identifier: 
Apache-2.0 Adapted from +# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision +from functools import cached_property +from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple, + TypedDict, Union, cast) + +import torch +from torch import nn +from transformers import BatchFeature, GotOcr2ImageProcessor +from transformers.activations import ACT2FN +from transformers.image_processing_utils import get_size_dict +from transformers.models.aya_vision import AyaVisionConfig +from transformers.models.aya_vision.processing_aya_vision import ( + AyaVisionProcessor) +from transformers.models.got_ocr2.image_processing_got_ocr2 import ( + get_optimal_tiled_canvas) + +from vllm.config import VllmConfig +from vllm.jsontree import json_map_leaves +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalFieldConfig, + PromptReplacement, PromptUpdate, + encode_tokens) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features + + +class AyaVisionImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(num_patches_total, num_channels, height, width)` + + `num_patches_total` is the total number of patches over each image over each + prompt in the batch. + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond to patch tokens. 
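+    True marks the placeholder positions that are later filled with projected
+    vision-patch features (via `scatter_patch_features`); False marks the
+    surrounding image-structure tokens, which keep their text embeddings.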
+ + Shape: `(batch_size * num_images, num_embeds)` + """ + + +class AyaVisionMultiModalProjector(nn.Module): + + def __init__(self, config: AyaVisionConfig): + super().__init__() + self.config = config + self.downsample_factor = config.downsample_factor + self.alignment_intermediate_size = getattr( + config, "alignment_intermediate_size", + config.text_config.hidden_size) + self.layernorm = nn.LayerNorm(config.vision_config.hidden_size * + (config.downsample_factor**2), + eps=config.adapter_layer_norm_eps) + + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * (config.downsample_factor**2), + self.alignment_intermediate_size, + bias=True, + ) + + self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation + # For SwiGLU, project down to half size since we split intermediate dim + self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, + config.text_config.hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + image_features = self.pixel_shuffle(image_features) + image_features = self.layernorm(image_features) + hidden_states = self.linear_1(image_features) + + # Split along last dimension and apply SwiGLU + x, gate = hidden_states.chunk(2, dim=-1) + hidden_states = self.act(gate) * x + + hidden_states = self.linear_2(hidden_states) + return hidden_states + + def pixel_shuffle(self, + image_features: torch.Tensor) -> torch.Tensor: # B, S, D + batch_size, seq_length, _ = image_features.shape + height = width = int(seq_length**0.5) + image_features = image_features.reshape(image_features.shape[0], width, + height, -1) + channels = image_features.shape[-1] + image_features = image_features.reshape( + batch_size, width, int(height / self.downsample_factor), + int(channels * self.downsample_factor)) + image_features = image_features.permute(0, 2, 1, 3) + image_features = image_features.reshape( + batch_size, int(height / self.downsample_factor), + int(width / self.downsample_factor), -1) + image_features = image_features.permute(0, 2, 1, 3) + return image_features + + +class AyaVisionProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> AyaVisionConfig: + return self.ctx.get_hf_config(AyaVisionConfig) + + def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor: + return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs) + + def get_image_processor(self) -> GotOcr2ImageProcessor: + return self.get_hf_processor().image_processor + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_max_image_tokens(self) -> int: + hf_processor = self.get_hf_processor() + image_processor = hf_processor.image_processor + image_size = self.get_image_size_with_most_features() + tokenizer = hf_processor.tokenizer + num_patches = self.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + size=image_processor.size, + min_patches=image_processor.min_patches, + max_patches=image_processor.max_patches) + image_string = hf_processor._prompt_split_image(num_patches) + x = encode_tokens( + tokenizer, + image_string, + add_special_tokens=False, + ) + return len(x) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + height = image_processor.size['height'] + width = image_processor.size['width'] + max_patches = image_processor.max_patches 
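+        # Worst-case size used for profiling: an image this large tiles into
+        # the maximum number of patches the processor allows, and therefore
+        # produces the maximum possible number of image tokens per item.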
+ return ImageSize(height=height * max_patches, + width=width * max_patches) + + def get_num_patches(self, *, image_width: int, image_height: int, + size: dict, min_patches: int, max_patches: int) -> int: + """ + Calculate the number of patches needed for a given image based on size + constraints. This method replicates and adjusts the logic from: + transformers/models/got_ocr2/image_processing_got_ocr2 + """ + size = get_size_dict(size, default_to_square=False) + num_columns, num_rows = get_optimal_tiled_canvas( + (image_height, image_width), (size["height"], size["width"]), + min_patches, max_patches) + num_blocks = num_columns * num_rows + return num_blocks if num_blocks == 1 else num_blocks + 1 + + +class AyaVisionDummyInputsBuilder( + BaseDummyInputsBuilder[AyaVisionProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + processor = self.info.get_hf_processor() + image_token = processor.image_token + + num_images = mm_counts.get("image", 0) + image_size = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=image_size.width, + height=image_size.height, + num_images=num_images) + } + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class AyaVisionMultiModalProcessor( + BaseMultiModalProcessor[AyaVisionProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + ) + hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_processor = hf_processor.image_processor + + hf_config = self.info.get_hf_config() + # HF processor pops the `num_patches` kwarg, which is needed by vLLM + if (images := + mm_data.get("images")) is not None and '' in prompt: + assert isinstance(images, list) + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems)) + image_sizes = [ + parsed_images.get_image_size(i) + for i in range(len(parsed_images)) + ] + num_patches = [ + self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + size=image_processor.size, + min_patches=image_processor.min_patches, + max_patches=image_processor.max_patches) + for image_size in image_sizes + ] + image_tokens_list = [ + hf_processor._prompt_split_image(num_patch) + for num_patch in num_patches + ] + tokenizer = self.info.get_tokenizer() + image_token_ids = [ + tokenizer.encode(image_tokens, add_special_tokens=False) + for image_tokens in image_tokens_list + ] + embed_is_patch = [ + torch.tensor(image_repl_tokens) == hf_config.image_token_index + for image_repl_tokens in image_token_ids + ] + processed_outputs["embed_is_patch"] = embed_is_patch + processed_outputs["num_patches"] = torch.tensor(num_patches) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: 
MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.image_token + image_processor = hf_processor.image_processor + + def get_replacement(item_idx: int): + images: ImageProcessorItems = mm_items.get("image", + ImageProcessorItems) + image_size: ImageSize = images.get_image_size(item_idx) + num_patches = self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + size=image_processor.size, + min_patches=image_processor.min_patches, + max_patches=image_processor.max_patches) + return hf_processor._prompt_split_image(num_patches=num_patches) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement, + ) + ] + + +def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int: + feature_layers = hf_config.vision_feature_layer + num_hidden_layers = hf_config.vision_config.num_hidden_layers + # If we have one feature layer, initialize up to that layer + if isinstance(feature_layers, int): + return _get_layer_index(feature_layers, num_hidden_layers) + # If we have multiple feature layers, initialize up to the deepest m + elif isinstance(feature_layers, (list, tuple)): + return max( + _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + +def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: + if feature_layer_index < 0: + return num_hidden_layers + feature_layer_index + 1 + return feature_layer_index + + +@MULTIMODAL_REGISTRY.register_processor( + AyaVisionMultiModalProcessor, + info=AyaVisionProcessingInfo, + dummy_inputs=AyaVisionDummyInputsBuilder) +class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: AyaVisionConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + num_hidden_layers = _get_num_hidden_layers(config) + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=maybe_prefix(prefix, "vision_model")) + self.vocab_size = config.text_config.vocab_size + self.multi_modal_projector = AyaVisionMultiModalProjector(config) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "model"), + # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm + architectures=["Cohere2ForCausalLM"]) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + **kwargs) -> torch.Tensor: + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype), + **kwargs) + + def select_features(leaf: torch.Tensor): + return self._select_image_features( + leaf, + 
strategy=self.config.vision_feature_select_strategy, + ) + + return cast( + Union[torch.Tensor, tuple[torch.Tensor, ...]], + json_map_leaves(select_features, image_features), + ) + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _process_image_input(self, image_input: AyaVisionImagePixelInputs, + **kwargs) -> list[torch.Tensor]: + assert self.vision_tower is not None + pixel_values = image_input["pixel_values"] + num_patches = image_input["num_patches"] + image_features = self._image_pixels_to_features( + self.vision_tower, pixel_values=pixel_values) + image_embeds = self.multi_modal_projector(image_features) + return [ + e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) + ] + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + if d.shape != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_dims}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + num_patches = kwargs.pop("num_patches", None) + embed_is_patch = kwargs.pop("embed_is_patch", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Aya Vision does not support image_embeds." + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + if num_patches is not None and not isinstance(num_patches, + (torch.Tensor, list)): + raise ValueError("Incorrect type of num_patches. " + f"Got type: {type(num_patches)}") + + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. 
" + f"Got type: {type(embed_is_patch)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + num_patches = flatten_bn(num_patches, concat=True) + embed_is_patch = flatten_bn(embed_is_patch) + return AyaVisionImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values(pixel_values), + num_patches=num_patches, + embed_is_patch=embed_is_patch, + ) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + image_features = self._process_image_input(image_input, **kwargs) + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=select_patch_features( + multimodal_embeddings), + placeholder_token_id=self.config.image_token_index) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py new file mode 100644 index 0000000..6a3112b --- /dev/null +++ b/vllm/model_executor/models/baichuan.py @@ -0,0 +1,479 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BaiChuan model compatible with HuggingFace weights.""" +import math +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BaiChuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class BaiChuanAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + position_embedding: str, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.postion_embedding = position_embedding + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + # pylint: disable=invalid-name + self.W_pack = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + # Create the alibi slopes and slice them. + if self.postion_embedding == "ALIBI": + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn") + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.W_pack(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.postion_embedding != "ALIBI": + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class BaiChuanDecoderLayer(nn.Module): + + def __init__(self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = BaiChuanAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + position_embedding=position_embedding, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = BaiChuanMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + 
quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class BaiChuanModel(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + position_embedding: str = "ROPE", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BaiChuanDecoderLayer(config, + position_embedding, + cache_config, + quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
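+                # (Quantized checkpoints may carry bias tensors for projections
+                # that vLLM fuses into a single parameter; a bias name with no
+                # matching entry in params_dict is simply dropped.)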
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + SupportsQuant): + packed_modules_mapping = { + "W_pack": ["W_pack"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + position_embedding: str = "ROPE", + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = BaiChuanModel(vllm_config=vllm_config, + prefix=prefix, + position_embedding=position_embedding) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.lm_head.weight.weight_loader = self.lm_head_weight_loader + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def lm_head_weight_loader(self, param: nn.Parameter, + loaded_weight: torch.Tensor): + # Unlike Baichuan, Baichuan2 normalizes the head weights. + # Refer to: + # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 + # Distinguish between Baichuan and Baichuan2 by checking the + # vocab size. This is suggested by + # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 + is_baichuan2 = self.config.vocab_size == 125696 + if is_baichuan2: + loaded_weight = torch.nn.functional.normalize(loaded_weight) + + default_weight_loader(param, loaded_weight) + + +class BaichuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 13B and Baichuan2 7B/13B. 
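+    The positional-embedding scheme is chosen from the checkpoint's hidden
+    size in __init__ below: 4096 selects RoPE (Baichuan2-7B), anything else
+    selects ALiBi (the 13B variants).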
+ NOTE: the class name has a lower case 'c'. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + if config.hidden_size == 4096: # baichuan2 7b + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ROPE") + else: # baichuan 13b, baichuan2 13b + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ALIBI") + + +class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 7B. + NOTE: the class name has an upper case 'C'. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ROPE") diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py new file mode 100644 index 0000000..e5896f5 --- /dev/null +++ b/vllm/model_executor/models/bamba.py @@ -0,0 +1,558 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only Bamba model.""" +# Added by the IBM Team, 2024 +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import BambaConfig + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant, SupportsV0Only) +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class BambaMLP(nn.Module): + + def __init__( + self, + config: BambaConfig, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BambaMixerDecoderLayer(nn.Module): + + def __init__(self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.mamba = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + chunk_size=config.mamba_chunk_size, + quant_config=quant_config) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, mamba_cache_params, + sequence_idx) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class BambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": BambaAttentionDecoderLayer, + "mamba": BambaMixerDecoderLayer +} + + +class BambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return 
layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + residual = None + num_attn = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + if isinstance(layer, BambaAttentionDecoderLayer): + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance(layer, BambaMixerDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + sequence_idx=seq_idx, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsV0Only, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Bamba currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = BambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Used to track and store by the Mamba cache between steps. 
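+        # The cache is created lazily on the first forward() call, once the
+        # number of Mamba layers and the conv/temporal state shapes are known.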
+ self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = self.config.mamba_expand * hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py new file mode 100644 index 0000000..04d6cde --- /dev/null +++ b/vllm/model_executor/models/bart.py @@ -0,0 +1,946 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Derived from BART implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model.""" +import math +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import BartConfig +from transformers.utils import logging + +from vllm.attention import Attention, AttentionType +from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVCrossParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsQuant, SupportsV0Only +from .utils import maybe_prefix + +logger = logging.get_logger(__name__) + + +def get_bsz_seq_len(input_ids): + shp = input_ids.shape + ndim = len(shp) + if ndim == 1: + return 1, input_ids.numel() + else: + return shp[:2] + + +class BartLearnedPositionalEmbedding(VocabParallelEmbedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is + # specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. + # Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + positions: torch.Tensor, + ) -> torch.Tensor: + """`input_ids' shape is expected to be [bsz x seqlen].""" + return super().forward(positions + self.offset) + + +class BartScaledWordEmbedding(VocabParallelEmbedding): + """ + This module overrides VocabParallelEmbedding's + forward by multiplying with embeddings scale. 
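+    (For BART, embed_scale is sqrt(d_model) when config.scale_embedding
+    is set, and 1.0 otherwise.)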
+ """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale + + +class BartParallelLMHead(ParallelLMHead): + """ + This module overrides ParallelLMHead's + forward by dividing by embeddings scale, + yielding effectively the inverse of + BartScaledWordEmbedding + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) / self.embed_scale + + +class BartEncoderAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = self.num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + return output + + +class BartDecoderSelfAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = self.num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + return output + + +class BartCrossAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + # TP sharding sizes is accounted for within "*Parallel" layers. + self.qkv_proj = QKVCrossParallelLinear(self.d_model, + self.d_model // + self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias, + quant_config=quant_config) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = self.num_heads # No GQA in bart + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + return output + + +class BartEncoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartEncoderAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = get_act_fn(config.activation_function) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.act = get_act_fn("gelu") + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Args: + hidden_states + torch.Tensor of *encoder* input embeddings. + Returns: + Encoder layer output torch.Tensor + """ + residual = hidden_states + hidden_states = self.self_attn(hidden_states=hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + return hidden_states + + +class BartDecoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartDecoderSelfAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.activation_fn = get_act_fn(config.activation_function) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + ''' + afeldman-nm: personally I would call this "cross-attention", + however I left the name as "encoder_attn" to maintain consistency + with the name of the pretrained weights. 
+ ''' + self.encoder_attn = BartCrossAttention( + self.embed_dim, + config.decoder_attention_heads, + config=config, + prefix=f"{prefix}.encoder_attn", + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_hidden_states + torch.Tensor of *decoder* input embeddings. + encoder_hidden_states + torch.Tensor of *encoder* input embeddings. + Returns: + Decoder layer output torch.Tensor + """ + residual = decoder_hidden_states + + # Self Attention + hidden_states = self.self_attn(hidden_states=decoder_hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + + residual = hidden_states + + hidden_states = self.encoder_attn( + decoder_hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +class BartEncoder(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* + self attention layers. Each layer is a [`BartEncoderLayer`]. + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = ""): + super().__init__() + + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + embed_dim = config.d_model + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([ + BartEncoderLayer(config, + cache_config, + quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.encoder_layers) + ]) + + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *encoder* input sequence tokens. 
+ Returns: + Decoder output torch.Tensor + """ + # retrieve input_ids and inputs_embeds + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states=hidden_states) + + return hidden_states + + +class BartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. + Each layer is a [`BartDecoderLayer`] + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = "", + ): + super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + + self.layers = nn.ModuleList( + [BartDecoderLayer(config,cache_config,quant_config, + prefix=f"{prefix}.layers.{layer_idx}") \ + for layer_idx in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + def forward( + self, + decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + decoder_positions + Positions of *decoder* input sequence tokens. 
+ encoder_hidden_states: + Tensor of encoder output embeddings + Returns: + Decoder output torch.Tensor + """ + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(decoder_input_ids) + else: + decoder_positions = inputs_embeds[:, -1] + + # embed positions + embed_pos = self.embed_positions(decoder_positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + return hidden_states + + +class BartModel(nn.Module, SupportsQuant): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.encoder = BartEncoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = BartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. 
+ Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states) + + return decoder_outputs + + +class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + base_model_prefix = "model" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + # currently all existing BART models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.config = config + self.model = BartModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + stacked_params_mapping = { + "q_proj": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "k_proj": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "v_proj": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." 
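+        # Strip the top-level "model." prefix if present, then translate
+        # checkpoint parameter names (beta/gamma/LayerNorm) to the names
+        # used here via params_mapping.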
+ key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + model_params_dict = dict(self.model.named_parameters()) + top_params_dict = dict(self.named_parameters()) + + weights_tuple_list = list(weights) + + shared_embedding_weight = None + shared_embedding_shard_id = None + + for name, loaded_weight in weights_tuple_list: + + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + if ('shared.weight' in name + or 'encoder.embed_tokens.weight' in name + or 'decoder.embed_tokens.weight' in name + or 'lm_head.weight' in name): + assert shared_embedding_weight is None, ( + "Conflicting embedding weights.") + shared_embedding_weight = loaded_weight + shared_embedding_shard_id = shard_id + else: + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in model_params_dict: + continue + + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + + # Assign shared weight values + encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] + encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", + default_weight_loader) + + decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] + decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", + default_weight_loader) + + lm_head_in_param = top_params_dict['lm_head.weight'] + lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader", + default_weight_loader) + + assert shared_embedding_weight is not None + + if shared_embedding_shard_id: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight, + shared_embedding_shard_id) + else: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py new file mode 100644 index 0000000..111b49a --- /dev/null +++ b/vllm/model_executor/models/bert.py @@ -0,0 +1,516 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import BertConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, PoolerConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import get_act_fn +from 
vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, + PoolingType) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) + +from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only +from .utils import WeightsMapper, maybe_prefix + + +class BertEmbedding(nn.Module): + + def __init__(self, config: BertConfig): + + super().__init__() + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = VocabParallelEmbedding( + config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = VocabParallelEmbedding( + config.type_vocab_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.position_ids = nn.Parameter( + torch.empty((1, config.max_position_embeddings)), ) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") + + def forward( + self, + input_ids: torch.Tensor, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_shape = input_ids.size() + + # Input embeddings. + inputs_embeds = self.word_embeddings(input_ids) + + # Position embeddings. + position_embeddings = self.position_embeddings(position_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
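+        # (For BERT this first token is the [CLS] token.)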
+ first_token_tensor = hidden_states[0, :] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@support_torch_compile +class BertEncoder(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.layer = nn.ModuleList([ + BertLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.layer.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + for layer in self.layer: + hidden_states = layer(hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.attention = BertAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + layer_norm_eps=config.layer_norm_eps, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attention") + + self.intermediate = BertIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + + self.output = BertOutput(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_norm_eps=config.layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output") + + def forward(self, hidden_states: torch.Tensor): + attn_output = self.attention(hidden_states) + intermediate_output = self.intermediate(attn_output) + output = self.output(intermediate_output, attn_output) + return output + + +class BertAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + layer_norm_eps: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.self = BertSelfAttention(hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.output") + + self.output = BertSelfOutput(hidden_size=hidden_size, + layer_norm_eps=layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + self_output = self.self(hidden_states) + return self.output(self_output, hidden_states) + + +class BertSelfAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_attention_heads + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + 
self.scaling = self.head_dim**-0.5 + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.attn = Attention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_ONLY) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + output = self.attn(q, k, v) + return output + + +class BertSelfOutput(nn.Module): + + def __init__(self, + hidden_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.dense = RowParallelLinear(input_size=hidden_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertIntermediate(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.dense = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.dense = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertModel(nn.Module, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + embedding_class: type = BertEmbedding, + add_pooling_layer: bool = False): + super().__init__() + config = vllm_config.model_config.hf_config + self.embeddings = embedding_class(config) + self.encoder = BertEncoder(vllm_config=vllm_config, + prefix=f"{prefix}.encoder") + self.pooler = BertPooler(config) if add_pooling_layer else None + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if 
inputs_embeds is not None: + hidden_states = inputs_embeds + else: + attn_metadata = get_forward_context().attn_metadata + assert hasattr(attn_metadata, "seq_lens_tensor") + hidden_states = self.embeddings( + input_ids=input_ids, + seq_lens=attn_metadata.seq_lens_tensor, + position_ids=position_ids, + token_type_ids=token_type_ids) + return self.encoder(hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if self.pooler is None and "pooler" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
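+    (By default the pooler performs CLS pooling with normalization; see
+    _build_pooler below.)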
+ """ + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + pooler_config = vllm_config.model_config.pooler_config + self.model = self._build_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self._pooler = self._build_pooler(pooler_config) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model(input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + self.model.load_weights(weights) + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + return BertModel(vllm_config=vllm_config, + prefix=prefix, + embedding_class=BertEmbedding) + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + return Pooler.from_config_with_defaults(pooler_config, + pooling_type=PoolingType.CLS, + normalize=True, + softmax=False) + + +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, + SupportsQuant): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + add_pooling_layer=True) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self._pooler = CrossEncodingPooler(config, self.classifier, + self.bert.pooler) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("bert."): + yield (name[len("bert."):], weight) + else: + self_weights.append((name, weight)) + + self.bert.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.bert(input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py new file mode 100644 index 0000000..f3d4889 --- /dev/null +++ b/vllm/model_executor/models/blip.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Minimal implementation of BlipVisionModel intended to be only used +within a vision language model.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from transformers import Blip2VisionConfig, BlipVisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .interfaces import SupportsQuant + + +def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_blip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa +class BlipVisionEmbeddings(nn.Module): + + def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]): + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + 
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + self.num_patches = get_blip_num_patches(image_size=self.image_size, + patch_size=self.patch_size) + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + position_embeds = self.position_embedding.to(target_dtype) + embeddings = embeddings + position_embeds[:, :embeddings.size(1), :] + + return embeddings + + +class BlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Union[BlipVisionConfig, Blip2VisionConfig], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.projection = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.projection", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + + qkv_states, _ = self.qkv(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.projection(out) + + return attn_output, None + + +class BlipMLP(nn.Module): + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = 
self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class BlipEncoderLayer(nn.Module): + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + # fallback to sdpa attention if tp unavailable + self.self_attn = BlipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = BlipMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class BlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`BlipEncoderLayer`]. + + Args: + config: BlipConfig + """ + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + BlipEncoderLayer(config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class BlipVisionModel(nn.Module, SupportsQuant): + config_class = BlipVisionConfig + main_input_name = "pixel_values" + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embeddings = BlipVisionEmbeddings(config) + self.encoder = BlipEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + + # If possible, skip post_layernorm to conserve memory + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + else: + self.post_layernorm = None + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.encoder(inputs_embeds=hidden_states) + + if self.post_layernorm is None: + return hidden_states + + return self.post_layernorm(hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in BlipVisionModel + if (name.startswith("post_layernorm") + and self.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py new file mode 100644 index 0000000..db9d42f --- /dev/null +++ b/vllm/model_executor/models/blip2.py @@ -0,0 +1,725 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, + apply_chunking_to_forward) + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptIndexTargets, + PromptInsertion, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .blip import BlipVisionModel +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +# We use this internally as placeholders since there is no image token +# defined on the HuggingFace repo +_IMAGE_TOKEN_ID = 50265 + + +class 
Blip2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class Blip2ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs] + + +class Blip2QFormerMultiHeadAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + ) -> None: + super().__init__() + + self.config = config + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of " + f"the number of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = (config.hidden_size // + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.scaling = self.attention_head_size**-0.5 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + kv_hidden_size = config.encoder_hidden_size + else: + kv_hidden_size = config.hidden_size + self.key = nn.Linear(kv_hidden_size, self.all_head_size) + self.value = nn.Linear(kv_hidden_size, self.all_head_size) + + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + if self.position_embedding_type != "absolute": + raise NotImplementedError("Unsupported position_embedding_type: " + f"{self.position_embedding_type}") + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + x = x.view(*x.size()[:-1], self.num_attention_heads, + self.attention_head_size) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ): + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_probs = torch.softmax(attention_scores * self.scaling, + dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
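+        # (nn.Dropout is a no-op in eval mode, so this only has an effect
+        # during training.)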
+ attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view(*context_layer.size()[:-2], + self.all_head_size) + + return context_layer + + +class Blip2QFormerSelfOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + ) -> None: + super().__init__() + + self.attention = Blip2QFormerMultiHeadAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=is_cross_attention, + ) + + self.output = Blip2QFormerSelfOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.Tensor]: + self_output = self.attention( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + attention_output = self.output(self_output, hidden_states) + + return attention_output + + +class Blip2QFormerIntermediate(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class Blip2QFormerOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerLayer(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + layer_idx: int, + ) -> None: + super().__init__() + + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config, + quant_config=quant_config, + cache_config=cache_config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + 
self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ): + attention_output = self.attention(hidden_states) + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + query_attention_output = self.crossattention( + query_attention_output, + encoder_hidden_states=encoder_hidden_states, + ) + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], + dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + + return layer_output + + def feed_forward_chunk(self, + attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query( + self, attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class Blip2QFormerEncoder(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + ) -> None: + super().__init__() + + self.config = config + + self.layer = nn.ModuleList([ + Blip2QFormerLayer(config, + quant_config=quant_config, + cache_config=cache_config, + layer_idx=layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ) -> torch.Tensor: + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + + hidden_states = layer_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return hidden_states + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025 +class Blip2QFormerModel(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + ) -> None: + super().__init__() + + self.config = config + + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.encoder = Blip2QFormerEncoder(config, + quant_config=quant_config, + cache_config=cache_config) + + def forward( + self, + query_embeds: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + ) -> torch.Tensor: + query_length = query_embeds.shape[1] + + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) + + sequence_output = self.encoder( + embedding_output, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return sequence_output + + +class 
Blip2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Blip2Config) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return hf_config.num_query_tokens + + +class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + + max_image_size = vision_config.image_size + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # HF processor always adds placeholders even when there's no image + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + image_token_id = vocab[""] + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, + info=Blip2ProcessingInfo, + dummy_inputs=Blip2DummyInputsBuilder) +class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, + SupportsQuant): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. 
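+        # BLIP-2 pipeline: the vision tower encodes pixel values, the Q-Former
+        # compresses them into config.num_query_tokens query embeddings, and
+        # language_projection maps those into the language model hidden size.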
+ self.vision_model = BlipVisionModel(config.vision_config, quant_config) + + self.query_tokens = nn.Parameter( + torch.zeros(1, config.num_query_tokens, + config.qformer_config.hidden_size)) + + self.qformer = Blip2QFormerModel(config.qformer_config, + cache_config=cache_config, + quant_config=quant_config) + + self.language_projection = nn.Linear( + config.qformer_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Blip2ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + + return Blip2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + + image_embeds = flatten_bn(image_embeds, concat=True) + + return Blip2ImageEmbeddingInputs( + type="image_embeds", + data=image_embeds, + ) + + raise AssertionError("This line should be unreachable.") + + def _image_pixels_to_features(self, vision_model: BlipVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_model(pixel_values) + + return image_features + + def _process_image_pixels(self, + inputs: Blip2ImagePixelInputs) -> torch.Tensor: + assert self.vision_model is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_model, pixel_values) + + def _process_image_input(self, + image_input: Blip2ImageInputs) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + image_features = self._process_image_pixels(image_input) + + query_tokens = self.query_tokens.expand(image_features.shape[0], -1, + -1) + query_output = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_features, + ) + + return self.language_projection(query_output) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + _IMAGE_TOKEN_ID) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[SamplerOutput, IntermediateTensors]: + """Run forward pass for BLIP-2. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + + Concretely, consider a text prompt: + `"Question: What's the content of the image? Answer:"`. + + Tokenizer outputs: + `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`. + + To reserve space in KV cache, we have to insert placeholder tokens + before they are inputted to the model, so the input processor prepends + dummy tokens (denoted as `50265`), resulting in: + `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`. + + We insert 32 tokens since it corresponds to the number of query + embeddings outputted by the Q-Former and inputted to the language model. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: The pixels in each input image. + + See also: + :class:`Blip2ImageInputs` + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py new file mode 100644 index 0000000..f960075 --- /dev/null +++ b/vllm/model_executor/models/bloom.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py +# Copyright 2023 The vLLM team. +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
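+
+# Note: BLOOM uses ALiBi attention biases rather than rotary or learned
+# positional embeddings; per-head slopes are computed by _get_alibi_slopes
+# below and handed to the attention backend.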
+"""Inference-only BLOOM model compatible with HuggingFace weights.""" +import math +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import BloomConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BloomAttention(nn.Module): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.n_head + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + # Create the alibi slopes and slice them. 
+ tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + del position_ids # Unused. + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v) + output, _ = self.dense(attn_output) + return output + + +class BloomMLP(nn.Module): + + def __init__( + self, + config: BloomConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + quant_config=quant_config, + ) + self.gelu_impl = get_act_fn("gelu") + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.dense_h_to_4h(x) + x = self.gelu_impl(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class BloomBlock(nn.Module): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + + self.input_layernorm = nn.LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + self.self_attention = BloomAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") + self.post_attention_layernorm = nn.LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + self.mlp = BloomMLP(config, quant_config) + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attention_output = self.self_attention( + position_ids=position_ids, + hidden_states=layernorm_output, + ) + attention_output = attention_output + residual + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. 
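+        # Feed-forward: h -> 4h -> GELU -> h, then add the residual chosen above.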
+ output = self.mlp(layernorm_output) + residual + return output + + +@support_torch_compile +class BloomModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.embed_dim = config.hidden_size + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.word_embeddings_layernorm = nn.LayerNorm( + self.embed_dim, eps=config.layer_norm_epsilon) + + # Transformer blocks + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: BloomBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + + # Final Layer Norm + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.transformer = BloomModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + 
sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if name == "lm_head.weight": + continue + if not name.startswith("transformer."): + name = "transformer." + name + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py new file mode 100644 index 0000000..f758c98 --- /dev/null +++ b/vllm/model_executor/models/chameleon.py @@ -0,0 +1,1145 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, + ChameleonVQVAEConfig) + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, row_parallel_weight_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces 
import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, + SupportsQuant) +from .utils import (flatten_bn, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix, merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +class ChameleonImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class ChameleonProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(ChameleonConfig) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + processor = self.get_hf_processor() + return processor.image_seq_length + + +class ChameleonDummyInputsBuilder( + BaseDummyInputsBuilder[ChameleonProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + config = self.info.get_hf_config() + + width = height = config.vq_config.resolution + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=width, + height=height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="" * num_images, + mm_data=mm_data, + ) + + +class ChameleonMultiModalProcessor( + BaseMultiModalProcessor[ChameleonProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + # HF processor adds sep token for chat mode + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + sep_token_id = vocab[tokenizer.sep_token] # type: ignore + + return prompt_tokens + [sep_token_id] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + image_start_id = vocab[processor.image_start_token] + image_token_id = vocab[processor.image_token] + image_end_id = vocab[processor.image_end_token] + + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=PromptUpdateDetails( + full=([image_start_id] + image_tokens + [image_end_id]), + features=image_tokens, + ), + ) + ] + + +class ChameleonLayerNorm(nn.LayerNorm): + + def __init__(self, 
hidden_size, *args, **kwargs): + super().__init__(hidden_size, *args, **kwargs) + self.normalized_shape = (hidden_size[-1], ) + + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.bias, + {"weight_loader": row_parallel_weight_loader}) + + def forward(self, hidden_states): + hidden_states = F.layer_norm(hidden_states, + self.normalized_shape, + None, + None, + eps=1e-5) + hidden_states = hidden_states * self.weight + self.bias + return hidden_states + + +# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP +class ChameleonMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa +class ChameleonAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 4096, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
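+            # e.g. 2 KV heads with tp_size=8: each KV head is replicated on
+            # 4 ranks and every rank keeps num_kv_heads=1 locally.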
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim)) + self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim)) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape for layernorm + q = q.reshape(-1, self.num_heads, self.head_dim) + k = k.reshape(-1, self.num_kv_heads, self.head_dim) + q = self.q_norm(q) + k = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class ChameleonDecoderLayer(nn.Module): + + def __init__( + self, + config: ChameleonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = ChameleonAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = ChameleonMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: 
torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class ChameleonSwinDecoderLayer(nn.Module): + + def __init__( + self, + config: ChameleonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = ChameleonAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = ChameleonMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa +class ChameleonVQVAEVectorQuantizer(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + self.re_embed = self.num_embeddings + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - + 2 * torch.einsum("bd,dn->bn", 
hidden_state_flattened, + self.embedding.weight.transpose(0, 1))) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = self.embedding(min_encoding_indices).view( + hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state)** + 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach())**2) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - + hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, + 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa +class ChameleonVQVAEEncoderConvDownsample(nn.Module): + + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, + pad=(0, 1, 0, 1), + mode="constant", + value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa +class ChameleonVQVAEEncoderResnetBlock(nn.Module): + + def __init__( + self, + config: ChameleonVQVAEConfig, + in_channels: int, + out_channels=None, + conv_shortcut=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None \ + else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=32, + num_channels=out_channels, + eps=1e-6, + affine=True) + self.dropout = torch.nn.Dropout(config.dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa +class ChameleonVQVAEEncoderAttnBlock(nn.Module): + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + 
padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm(hidden_states) + query_states = self.q(hidden_states) + key_states = self.k(hidden_states) + value_states = self.v(hidden_states) + + # compute attention + batch_size, channels, height, width = query_states.shape + query_states = query_states.reshape(batch_size, channels, + height * width).permute(0, 2, 1) + key_states = key_states.reshape(batch_size, channels, height * width) + attn_weights = torch.bmm(query_states, key_states) + attn_weights = attn_weights * (int(channels)**(-0.5)) + attn_weights = F.softmax(attn_weights, dim=2) + + # attend to values + value_states = value_states.reshape(batch_size, channels, + height * width) + attn_weights = attn_weights.permute(0, 2, 1) + attn_output = torch.bmm(value_states, + attn_weights).reshape(batch_size, channels, + height, width) + + attn_output = self.proj_out(attn_output) + return residual + attn_output + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa +class ChameleonVQVAEEncoder(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + + self.conv_in = torch.nn.Conv2d(in_channels, + base_channels, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_channel_multiplier = (1, ) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + )) + block_in = block_out + if (config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla"): + attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock( + block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, + num_channels=block_in, + eps=1e-6, + affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * latent_channels if double_latent else latent_channels, + kernel_size=3, + stride=1, 
+ padding=1, + ) + + def forward(self, pixel_values: torch.Tensor): + pixel_values = pixel_values.to(self.conv_in.weight.dtype) + + # downsampling + hidden_states = [self.conv_in(pixel_values)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + hidden_state = self.down[i_level].block[i_block]( + hidden_states[-1]) + if len(self.down[i_level].attn) > 0: + hidden_state = self.down[i_level].attn[i_block]( + hidden_state) + hidden_states.append(hidden_state) + if i_level != self.num_resolutions - 1: + hidden_states.append(self.down[i_level].downsample( + hidden_states[-1])) + + # middle + last_hidden_state = hidden_states[-1] + last_hidden_state = self.mid.block_1(last_hidden_state) + last_hidden_state = self.mid.attn_1(last_hidden_state) + last_hidden_state = self.mid.block_2(last_hidden_state) + + # end + last_hidden_state = self.norm_out(last_hidden_state) + last_hidden_state *= torch.sigmoid(last_hidden_state) + last_hidden_state = self.conv_out(last_hidden_state) + return last_hidden_state + + +# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa +class ChameleonVQVAE(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.encoder = ChameleonVQVAEEncoder(config) + self.quantize = ChameleonVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, + config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, + config.latent_channels, 1) + self.eval() # Chameleon's VQ model is frozen + + def encode( + self, pixel_values: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa +class ChameleonImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. 
+ """ + + def __init__(self, vocab_map: Dict[str, int]): + self.vocab_map = vocab_map + self.image_token_id = vocab_map.get("") + + @cached_property + def val2name(self): + return {v: k for k, v in self.vocab_map.items()} + + @cached_property + def image_tokens(self): + return sorted([ + val for name, val in self.vocab_map.items() + if name.startswith("IMGIMG") + ]) + + @cached_property + def bpe2img(self): + img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} + + def remap(old_name: str) -> str: + return "".join( + img_tkn_chr_mapping.get(c, c) + for c in old_name[len("IMGIMG"):-1]) + + return { + tok: int(remap(self.val2name[tok])) + for tok in self.image_tokens + } + + @cached_property + def img2bpe(self): + return {v: k for k, v in self.bpe2img.items()} + + @cached_property + def bpe2img_search_tensors(self): + return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor( + sorted(self.bpe2img.values())) + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +class ChameleonModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + self.vocabulary_mapping = ChameleonImageVocabularyMapping( + config.vocabulary_map) + decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ + else ChameleonSwinDecoderLayer + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.vqmodel = ChameleonVQVAE(config.vq_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. 
+ """ + batch_size = pixel_values.shape[0] + _, _, image_toks = self.vqmodel.encode(pixel_values) + bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) + bpe_toks = bpe_toks.view(batch_size, -1) + return bpe_toks + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + ChameleonMultiModalProcessor, + info=ChameleonProcessingInfo, + dummy_inputs=ChameleonDummyInputsBuilder) +class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + self.model = ChameleonModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + vq_config: ChameleonVQVAEConfig = self.config.vq_config + expected_dims = (3, vq_config.resolution, vq_config.resolution) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. 
" + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + + return ChameleonImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens(image_input["data"].to( + self.config.torch_dtype)) + vision_embeddings = self.model.get_input_embeddings(image_tokens) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.model.vocabulary_mapping.image_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + # Disallow image tokens which does not include special + # begin-image and end-image tokens + if logits is not None: + image_tokens = self.model.vocabulary_mapping.image_tokens + logits[:, image_tokens] = torch.finfo(logits.dtype).min + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. 
+ if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + + use_default_weight_loading = False + if "vqmodel" in name: + if self.model.vqmodel is not None: + # We only do sharding for language model and + # not vqvae for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. " + f"{name}), but not found the expected name in " + f"the model (e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if use_default_weight_loading and name in params_dict: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py new file mode 100644 index 0000000..1b1738f --- /dev/null +++ b/vllm/model_executor/models/chatglm.py @@ -0,0 +1,486 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +"""Inference-only ChatGLM model compatible with THUDM weights.""" +import json +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from torch.nn import LayerNorm + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import ChatGLMConfig + +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (AutoWeightsLoader, WeightsMapper, 
is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class GLMAttention(nn.Module): + + def __init__( + self, + config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.multi_query_attention = config.multi_query_attention + self.total_num_kv_heads = (config.multi_query_group_num + if config.multi_query_attention else + config.num_attention_heads) + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.add_bias_linear or config.add_qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=config.add_bias_linear, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 + rope_ratio = getattr(config, "rope_ratio", 1.0) + max_positions = getattr(config, "seq_length", 8192) + # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False, + # which is equivalent to is_neox_style=True + is_neox_style = not config.original_rope + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim // 2, + max_position=max_positions, + base=10000 * rope_ratio, + is_neox_style=is_neox_style, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + context_layer = self.attn(q, k, v) + attn_output, _ = self.dense(context_layer) + return attn_output + + +class GLMMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__( + self, + config: ChatGLMConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. 
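+        # dense_h_to_4h packs the gate and up projections into a single
+        # matmul (output width 2 * ffn_hidden_size); SiluAndMul below splits
+        # that output in half and computes SiLU(gate) * up.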
+ self.dense_h_to_4h = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=config.add_bias_linear, + quant_config=quant_config, + prefix=f"{prefix}.dense_h_to_4h", + ) + + self.activation_func = SiluAndMul() + + # Project back to h. + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h", + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel, _ = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output, _ = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__( + self, + config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = layer_norm_func(config.hidden_size, + eps=config.layernorm_epsilon) + + # Self attention. + self.self_attention = GLMAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + # hidden_states: [num_tokens, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.self_attention( + hidden_states=layernorm_output, + position_ids=position_ids, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = residual + attention_output + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.mlp(layernorm_output) + residual + + return output + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__( + self, + config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + self.start_layer, self.end_layer, self.layers = make_layers( + self.num_layers, + lambda prefix: GLMBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if self.post_layer_norm: + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. 
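+            # Only applied on the last pipeline-parallel rank; earlier ranks
+            # return IntermediateTensors from forward() before reaching it.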
+ self.final_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> Union[torch.Tensor, IntermediateTensors]: + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states=hidden_states, + position_ids=position_ids) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +@support_torch_compile +class ChatGLMModel(nn.Module, SupportsQuant): + packed_modules_mapping = { + "linear_proj.merged_proj": + ["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embedding = VocabParallelEmbedding(config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embedding") + + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + self.encoder = GLMTransformer(config, + cache_config, + quant_config, + prefix=f"{prefix}.encoder") + + self.output_layer = ParallelLMHead(config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.output_layer") + + self.make_empty_intermediate_tensors = ( + self.encoder.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + # Run encoder. + hidden_states = self.encoder( + hidden_states=hidden_states, + position_ids=positions, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("linear_proj.merged_proj", "linear_proj.gate_proj", 0), + ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "rotary_pos_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ChatGLMBaseModel(nn.Module): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".word_embeddings": ""}, ) + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[ChatGLMModel] = ChatGLMModel, + ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.lora_config = lora_config + self.multimodal_config = multimodal_config + + self.quant_config = quant_config + self.max_position_embeddings = getattr(config, "max_sequence_length", + 8192) + self.transformer = transformer_type(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + if self.config.tie_word_embeddings: + self.transformer.output_layer.weight = ( + self.transformer.embedding.weight) + self.lm_head = self.transformer.output_layer + self.logits_processor = LogitsProcessor(config.padded_vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + +class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, + SupportsQuant): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + if hasattr(config, "vision_config"): + hf_overrides = {"architectures": ["GLM4VForCausalLM"]} + raise RuntimeError( + "The configuration of this model indicates that it supports " + "vision inputs, but you instantiated the text-only version " + "of this model. 
Please use the vision model by setting " + f"`--hf-overrides '{json.dumps(hf_overrides)}'`") + + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py new file mode 100644 index 0000000..dc3aa9c --- /dev/null +++ b/vllm/model_executor/models/clip.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Minimal implementation of CLIPVisionModel intended to be only used +within a vision language model.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from transformers import CLIPVisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsQuant + +from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs + + +class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return self.get_patch_grid_length()**2 + 1 + + def get_max_image_tokens(self) -> int: + return self.get_patch_grid_length()**2 + 1 + + def get_image_size(self) -> int: + return self.vision_config.image_size + + def get_patch_size(self) -> int: + return self.vision_config.patch_size + + def get_patch_grid_length(self) -> int: + image_size, patch_size = self.get_image_size(), self.get_patch_size() + assert image_size % patch_size == 0 + return image_size // patch_size + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa +class CLIPVisionEmbeddings(nn.Module): + + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + assert self.image_size % self.patch_size == 0 + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + self.register_buffer("position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = 
self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.out_proj(out) + + return attn_output, None + + +class CLIPMLP(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self_attn = CLIPAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = 
self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList([ + CLIPEncoderLayer(config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward( + self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + ) -> Union[torch.Tensor, list[torch.Tensor]]: + hidden_states_pool = [inputs_embeds] + hidden_states = inputs_embeds + + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool + return hidden_states + + +class CLIPVisionTransformer(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + + # NOTE: This typo of "layrnorm" is not fixed on purpose to match + # the original transformers code and name of the model weights. + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + + # If possible, skip post_layernorm to conserve memory + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + self.post_layernorm = None + + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + return_all_hidden_states = feature_sample_layers is not None + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states) + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, feature_sample_layers, self.post_layernorm, + self.config.num_hidden_layers) + + return encoder_outputs + + +class CLIPVisionModel(nn.Module, SupportsQuant): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.vision_model = CLIPVisionTransformer( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + require_post_norm=require_post_norm, + prefix=f"{prefix}.vision_model") + + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + return self.vision_model(pixel_values, feature_sample_layers) + + @property + def device(self): + return next(self.parameters()).device + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in CLIPVisionModel + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py new file 
mode 100644 index 0000000..b1757c1 --- /dev/null +++ b/vllm/model_executor/models/commandr.py @@ -0,0 +1,469 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is based on the LLama model definition file in transformers +"""PyTorch Cohere model.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import CohereConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name, + row_parallel_weight_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +# @torch.compile(backend=current_platform.simple_compile_backend) +def layer_norm_func(hidden_states, weight, variance_epsilon): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + + variance_epsilon) + hidden_states = weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + +class LayerNorm(nn.Module): + + def __init__(self, param_shape=None, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(param_shape)) + self.variance_epsilon = eps + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) + + def forward(self, hidden_states, residuals=None): + hidden_states = 
layer_norm_func(hidden_states, self.weight, + self.variance_epsilon) + return hidden_states, residuals + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere +class CohereMLP(nn.Module): + + def __init__( + self, + config: CohereConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class CohereAttention(nn.Module): + + def __init__( + self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
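+            # e.g. a single KV head with tp_size=4 keeps a full copy of that
+            # head on every rank, and num_kv_heads below is clamped to >= 1.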
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = getattr( + config, "model_max_length", None) or getattr( + config, "max_position_embeddings", 8192) + self.rope_theta = config.rope_theta + self.rope_scaling = getattr(config, "rope_scaling", None) + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=self.rope_scaling, + is_neox_style=False, + ) + + # Model v2 has interleaved sliding windows, v1 does not + interleaved_sliding_window = getattr(config, + "interleaved_sliding_window", + None) + self.v1 = interleaved_sliding_window is None + + layer_idx = extract_layer_index(prefix) + layer_has_sliding_window = ( + getattr(config, "sliding_window_pattern", False) + and (layer_idx + 1) % self.config.sliding_window_pattern != 0) + + self.sliding_window = (interleaved_sliding_window + if layer_has_sliding_window else None) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn") + if self.use_qk_norm: + self.q_norm = LayerNorm(param_shape=(self.num_heads, + self.head_dim), + eps=config.layer_norm_eps) + self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, + self.head_dim), + eps=config.layer_norm_eps) + + def _apply_qk_norm(self, q, k): + q = q.view(*q.shape[:-1], -1, self.head_dim) + k = k.view(*k.shape[:-1], -1, self.head_dim) + q, _ = self.q_norm(q) + k, _ = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q, k = self._apply_qk_norm(q, k) + if self.v1 or self.sliding_window: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CohereDecoderLayer(nn.Module): + + def __init__(self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + + self.mlp = CohereMLP(config, quant_config=quant_config) + self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states, residual = self.input_layernorm(hidden_states, residual) + 
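+        # Command R uses a parallel residual block: attention and the MLP both
+        # read the same normalized hidden states, and their outputs are summed
+        # with the residual below rather than applied sequentially.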
hidden_states_attention = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states_mlp = self.mlp(hidden_states) + # Add everything together + hidden_states = residual + hidden_states_attention + hidden_states_mlp + + return hidden_states, residual + + +@support_torch_compile +class CohereModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CohereDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + embedding_modules = {"embed_tokens": "input_embeddings"} + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + # currently all existing command R models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.quant_config = quant_config + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=config.logit_scale) + self.model = CohereModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + 
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + is_not_lora = hasattr(self.model.embed_tokens, 'weight') + if is_not_lora: + logits = self.logits_processor(self.model.embed_tokens, + hidden_states, sampling_metadata) + else: + logits = self.logits_processor(self.model.embed_tokens.base_layer, + hidden_states, sampling_metadata) + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
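+                # Checkpoints may store scales as "*.kv_scale" while the model
+                # expects "*.attn.kv_scale"; the helper returns None when no
+                # matching parameter exists, and the weight is then skipped.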
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/constant_size_cache.py b/vllm/model_executor/models/constant_size_cache.py new file mode 100644 index 0000000..d073a7d --- /dev/null +++ b/vllm/model_executor/models/constant_size_cache.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + +import torch + +from vllm.attention.backends.utils import PAD_SLOT_ID + + +class ConstantSizeCache(ABC): + """ + Abstract base class for managing constant size caches + like Mamba and Minimax. + """ + + def __init__(self, max_batch_size: int): + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the cache + self.cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.free_cache_indices = list(range(max_batch_size)) + + @property + @abstractmethod + def cache(self) -> Any: + """Return the underlying cache tensor(s)""" + pass + + @abstractmethod + def _copy_cache(self, from_index: int, to_index: int): + """Copy cache data from one index to another""" + pass + + def current_run_tensors(self, **kwargs) -> Tuple: + """ + Return the tensors for the current run's conv and ssm state. + """ + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + + self._release_finished_requests(finished_requests_ids) + state_indices = self._prepare_current_run_cache( + request_ids_to_seq_ids, finished_requests_ids) + + state_indices_tensor = torch.as_tensor(state_indices, + dtype=torch.int32, + device="cuda") + cache_tensors = self.cache + else: + # CUDA graph capturing runs + cache_tensors, state_indices_tensor = kwargs[ + "seqlen_agnostic_capture_inputs"] + + return (cache_tensors, state_indices_tensor) + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant state_indices into the CUDA graph input buffer + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + assert "seqlen_agnostic_capture_inputs" in input_buffers + _, input_state_indices_buffer = input_buffers[ + "seqlen_agnostic_capture_inputs"] + + self._release_finished_requests(finished_requests_ids) + state_indices = self._prepare_current_run_cache( + request_ids_to_seq_ids, finished_requests_ids) + cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( + state_indices) + state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) + + input_state_indices_buffer.copy_( + torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Cache during the CUDA graph replay + runs. 
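+        At capture time the returned index tensor holds only PAD_SLOT_ID
+        entries; copy_inputs_before_cuda_graphs() fills in the real state
+        indices (padding the remainder) before every graph replay.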
+ """ + state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") + return (self.cache, state_indices_tensor) + + def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, + finished_requests_ids) -> int: + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. + """ + if cur_rid in finished_requests_ids: + # set as pad, do not allocate destination index + return PAD_SLOT_ID + elif cur_rid not in self.cache_indices_mapping: + destination_index = self.free_cache_indices.pop() + self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} + return destination_index + elif seq_id not in (seq_ids2indices := + self.cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened, so we copy the + # existing cache into the siblings seq_ids caches + index_exists = next(iter(seq_ids2indices.values())) + # case of decoding n>1, copy prefill cache to decoding indices + destination_index = self.free_cache_indices.pop() + self._copy_cache(from_index=index_exists, + to_index=destination_index) + self.cache_indices_mapping[cur_rid][seq_id] = destination_index + return destination_index + else: + return self.cache_indices_mapping[cur_rid][seq_id] + + def _prepare_current_run_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], + finished_requests_ids: List[str]) -> List[int]: + return [ + self._assign_seq_id_to_cache_index(req_id, seq_id, + finished_requests_ids) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + + def _release_finished_requests(self, + finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.cache_indices_mapping: + for seq_id in self.cache_indices_mapping[req_id]: + self.free_cache_indices.append( + self.cache_indices_mapping[req_id][seq_id]) + self.cache_indices_mapping.pop(req_id) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py new file mode 100644 index 0000000..b665298 --- /dev/null +++ b/vllm/model_executor/models/dbrx.py @@ -0,0 +1,479 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.dbrx import DbrxConfig + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, 
make_layers, + maybe_prefix) + + +class DbrxRouter(nn.Module): + """A Router implementation for DBRX that returns logits for each expert + per token. + """ + + def __init__( + self, + config: DbrxConfig, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.ffn_config.moe_num_experts + self.d_model = config.d_model + self.layer = ReplicatedLinear( + self.d_model, + self.num_total_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_logits, _ = self.layer(hidden_states) + return router_logits + + +class DbrxExperts(FusedMoE): + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + prefix: str = "", + ): + super().__init__( + num_experts=config.ffn_config.moe_num_experts, + top_k=config.ffn_config.moe_top_k, + hidden_size=config.d_model, + intermediate_size=config.ffn_config.ffn_hidden_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=get_tensor_model_parallel_world_size(), + prefix=prefix, + ) + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.d_model = config.d_model + self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // + self.tp_size) + + # Define custom weight loader for dbrx model + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, param_name: str): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + # DBRX uses GLU for each experts. + # GLU has 3 linear layers: w1, v1 and w2. + if weight_name.endswith("w1"): + if param_name.endswith("weight"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :] + elif param_name.endswith("weight_scale"): + param_data[:, 0] = loaded_weight + else: + param_data = loaded_weight + if weight_name.endswith("v1"): + if param_name.endswith("weight"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, shard_size:2 * + shard_size, :] = loaded_weight[:, shard, :] + elif param_name.endswith("weight_scale"): + param_data[:, 1] = loaded_weight + else: + param_data[:] = loaded_weight + if weight_name.endswith("w2"): + if param_name.endswith("weight"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ).transpose(1, 2) + param_data[:] = loaded_weight[:, :, shard] + else: + param_data[:] = loaded_weight + + +class DbrxMoE(nn.Module): + """A tensor-parallel MoE implementation for DBRX. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
+ """ + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.router = DbrxRouter(config, self.params_dtype) + + self.experts = DbrxExperts(config=config, + quant_config=quant_config, + params_dtype=self.params_dtype, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.d_model) + # router_logits: (num_tokens, n_experts) + router_logits = self.router(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class DbrxAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads + self.total_num_kv_heads = config.attn_config.kv_n_heads + self.clip_qkv = config.attn_config.clip_qkv + self.rope_theta = config.attn_config.rope_theta + self.max_position = config.max_seq_len + + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( + self.d_model, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.d_model, + self.d_model, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + self.tp_size = tp_world_size + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + hidden_states, _ = self.out_proj(attn_output) + return hidden_states + + +class DbrxFusedNormAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.attn = DbrxAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.norm_1 = nn.LayerNorm(self.d_model) + self.norm_2 = nn.LayerNorm(self.d_model) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm_1(hidden_states) + x = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + ) + hidden_states = residual + x + residual = hidden_states + hidden_states = self.norm_2(hidden_states) + return hidden_states, residual + + +class DbrxBlock(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.norm_attn_norm = DbrxFusedNormAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.norm_attn_norm") + self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + hidden_states, residual = self.norm_attn_norm( + position_ids=position_ids, + hidden_states=hidden_states, + ) + hidden_states = self.ffn(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class DbrxModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: DbrxBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.blocks", + ) + self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) + for module in self.modules(): + if hasattr(module, "bias") and isinstance(module.bias, + nn.Parameter): + # Remove the bias term in Linear and LayerNorm. 
+ module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for block in self.blocks[self.start_layer:self.end_layer]: + hidden_states = block(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm_f(hidden_states) + return hidden_states + + +class DbrxForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + if config.tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models.") + self.quant_config = quant_config + self.unpadded_vocab_size = config.vocab_size + self.transformer = DbrxModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + expert_params_mapping = [( + "w13" if weight_name in ["w1", "v1"] else "w2", + f"mlp.{weight_name}", + ) for weight_name in ["w1", "v1", "w2"]] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) 
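+                # Some checkpoints store the kv-cache scale as a 1-element tensor
+                # rather than a scalar; take element 0 in that case before loading.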
+ loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + if name.endswith(("w1", "w2", "v1")): + name = name + "_weight" + for param_name, weight_name in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, weight_name, name) + break + + else: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py new file mode 100644 index 0000000..f0212f3 --- /dev/null +++ b/vllm/model_executor/models/deepseek.py @@ -0,0 +1,488 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Deepseek model.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class DeepseekMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}.") + + self.experts = nn.ModuleList([ + DeepseekMLP(hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False) + for idx in range(self.n_routed_experts) + ]) + self.pack_params() + + self.gate = ReplicatedLinear(config.hidden_size, + self.n_routed_experts, + bias=False, + quant_config=None) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True) + + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class DeepseekAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number 
of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = DeepseekAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + 
positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = DeepseekModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return 
logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py new file mode 100644 index 0000000..558ec23 --- /dev/null +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Iterable, Optional, Set, Tuple + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .deepseek_v2 import (DeepseekV2DecoderLayer, + get_spec_layer_idx_from_weight_name) +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class 
DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, + cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states), + sampling_metadata) + return logits + + +class DeepSeekMTP(nn.Module): + + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + self.sampler = get_sampler() + + def forward( + self, + input_ids: 
torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. 
+ Add .mtp_block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + return name diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py new file mode 100644 index 0000000..943a82e --- /dev/null +++ b/vllm/model_executor/models/deepseek_v2.py @@ -0,0 +1,937 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
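+# NOTE: this implementation relies on Iluvatar ixformer fused kernels
+# (ops.mla_rope, ops.mla_copy_kv) for the MLA attention path, and can fuse
+# kv_a_proj_with_mqa into the query projection at weight-load time
+# (see inject_layer inside DeepseekV2ForCausalLM.load_weights).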
+"""Inference-only DeepseekV2/DeepseekV3 model.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gguf import GGUFConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +import ixformer.inference.functions as ops +QUANT_CONFIG_FOR_GGUF = None +from torch.cuda import profiler + + +class DeepseekV2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV2MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + self.add_shared_on_experts = quant_config.__class__.__name__ != "GGUFConfig" and config.n_shared_experts is not None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + if self.add_shared_on_experts: + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + extra_residual=shared_output, + routed_scaling_factor=self.routed_scaling_factor, + ) + else: + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if self.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV2Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = 
max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward_opt( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q_latent_kpe = self.q_a_proj(hidden_states)[0] + q, kv_a, k_pe = q_latent_kpe.split([self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q_latent_kpe = self.q_proj(hidden_states)[0] + q, kv_a, k_pe = q_latent_kpe.split([self.num_heads * self.qk_head_dim, self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + kv_a = self.kv_a_layernorm(kv_a) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.empty_like(q) + v = torch.empty_like(q) + ops.mla_rope(positions, q_pe, k_pe, k[...,self.qk_nope_head_dim:], self.rotary_emb.cos_sin_cache) + ops.mla_copy_kv(k_nope, v_nope, k, v) + + attn_output = self.attn(q, k, v) + attn_output = attn_output.view( + -1, self.num_local_heads, self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + kv_a, k_pe = 
self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0] + kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + kv_a = self.kv_a_layernorm(kv_a) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.empty_like(q) + v = torch.empty_like(q) + ops.mla_rope(positions, q_pe, k_pe, k[...,self.qk_nope_head_dim:], self.rotary_emb.cos_sin_cache) + ops.mla_copy_kv(k_nope, v_nope, k, v) + + attn_output = self.attn(q, k, v) + attn_output = attn_output.view( + -1, self.num_local_heads, self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV2MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + """ + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + global QUANT_CONFIG_FOR_GGUF + if QUANT_CONFIG_FOR_GGUF is None: + import copy + quant_config_for_gguf = copy.copy(quant_config) if quant_config is not None and quant_config.__class__.__name__ == "GGUFConfig" else quant_config + if hasattr(quant_config_for_gguf, "keep_half_weight"): + setattr(quant_config_for_gguf, "keep_half_weight", True) + QUANT_CONFIG_FOR_GGUF = quant_config_for_gguf + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=QUANT_CONFIG_FOR_GGUF, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=QUANT_CONFIG_FOR_GGUF, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = 
ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=QUANT_CONFIG_FOR_GGUF, + prefix=f"{prefix}.kv_b_proj") + from vllm import envs + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=QUANT_CONFIG_FOR_GGUF, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. + # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + q = self.q_a_layernorm(q) + kv_a = self.kv_a_layernorm(kv_a) + else: + q = hidden_states + latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + kv_a = self.kv_a_layernorm(kv_a) + + return self.mla_attn(q, kv_a, k_pe) + + def forward_opt( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q_latent_kpe = self.q_a_proj(hidden_states)[0] + q, kv_a, k_pe, _ = q_latent_kpe.split([self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim, self.q_a_proj.output_padding_size], dim=1) + q = self.q_a_layernorm(q) + kv_a = self.kv_a_layernorm(kv_a) + else: + q = hidden_states + latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + kv_a = self.kv_a_layernorm(kv_a) + + return self.mla_attn(q, kv_a, k_pe) + + +class DeepseekV2DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + 
quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + if model_config.use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + if isinstance(self.mlp, DeepseekV2MoE) and \ + hidden_states.dtype == torch.float16: + # This is a special case to avoid FP16 overflow + hidden_states *= 1. / self.routed_scaling_factor + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, DeepseekV2MLP) and \ + hidden_states.dtype == torch.float16: + # This is a special case to avoid FP16 overflow + hidden_states *= 1. / self.routed_scaling_factor + residual *= 1. 
/ self.routed_scaling_factor + return hidden_states, residual + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV2ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def 
compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"] + # add your opt here.. + def inject_layer(layer, quant_method, is_mla): + q_lora_rank = getattr(layer, "q_lora_rank", None) + if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]: + if q_lora_rank is not None: + layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_a_proj, "weight_scale"): + layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + elif not is_mla: + layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_proj, "weight_scale"): + layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + else: + return + del layer.kv_a_proj_with_mqa.weight + del layer.kv_a_proj_with_mqa + layer.forward = layer.forward_opt + elif quant_method == "GGUFLinearMethod": + pass + elif quant_method == "AWQMarlinLinearMethod": + dtype = layer.kv_a_proj_with_mqa.qweight.dtype + assert dtype == torch.int32 + if layer.q_lora_rank is not None: + layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1) + layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + elif not is_mla: + layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1) + layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + else: + return + + del layer.kv_a_proj_with_mqa.qweight + del layer.kv_a_proj_with_mqa + layer.forward = layer.forward_opt + else: + pass + + for _, layer in self.model.named_modules(): + if layer.__class__.__name__ in ["DeepseekV2Attention","DeepseekV2MLAAttention"]: + if hasattr(layer.kv_a_proj_with_mqa, "scheme"): + quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__ + else: + quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__ + if quant_method not in opt_support_quant_method: + break + + inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "DeepseekV2MLAAttention") + + return loaded_params + + +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + 
i + return None diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py new file mode 100644 index 0000000..4554a99 --- /dev/null +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -0,0 +1,673 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py +"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, + MlpProjectorConfig, + VisionEncoderConfig) +from vllm.transformers_utils.processors.deepseek_vl2 import ( + DeepseekVLV2Processor) +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.utils import is_list_of + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +# The image token id may vary +_IMAGE_TOKEN = "<image>" + + +class DeepseekVL2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + """ + images_spatial_crop: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + """ + + +class DeepseekVL2VImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone.
+ """ + + +DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs, + DeepseekVL2VImageEmbeddingInputs] + + +class MlpProjector(nn.Module): + + def __init__(self, cfg: MlpProjectorConfig): + + super().__init__() + + self.cfg = cfg + assert not cfg.token_pooling, ( + "Token pooling is not supported currently.") + + if cfg.projector_type == "downsample_mlp_gelu": + mlp_depth = cfg.depth + mlp_ratio = cfg.mlp_ratio + modules = [ + nn.Linear( + cfg.input_dim * cfg.downsample_ratio * + cfg.downsample_ratio, cfg.n_embed * mlp_ratio) + ] + for _ in range(1, mlp_depth - 1): + modules.append(nn.GELU()) + modules.append( + nn.Linear(cfg.n_embed * mlp_ratio, + cfg.n_embed * mlp_ratio)) + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) + modules = nn.Sequential(*modules) + + else: + raise NotImplementedError( + f"Unsupported projector type: {cfg.projector_type}") + + self.layers = modules + + def forward(self, x): + bs, hw, input_dim = x.shape + h = w = int((hw)**0.5) + """compute padding""" + if h % self.cfg.downsample_ratio: + pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio + else: + pad = 0 + x = x.reshape(bs, h, w, input_dim) + if pad > 0: + x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) + """4 to 1 concat""" + x = x.permute(0, 3, 1, 2) # B, C, H, W + x = F.unfold(x, + kernel_size=self.cfg.downsample_ratio, + stride=self.cfg.downsample_ratio, + padding=0) # B, C*4, HW // 4 + x = x.permute(0, 2, 1) + + return self.layers(x) + + +class DeepseekVL2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(DeepseekVLV2Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens(self, + *, + image_width: int, + image_height: int, + cropping: bool = True) -> int: + hf_processor = self.get_hf_processor() + image_size = hf_processor.image_size + patch_size = hf_processor.patch_size + downsample_ratio = hf_processor.downsample_ratio + + if cropping: + best_width, best_height = hf_processor.select_best_resolution( + (image_width, image_height)) + num_width_tiles, num_height_tiles = (best_width // image_size, + best_height // image_size) + else: + num_width_tiles = num_height_tiles = 1 + + h = w = math.ceil((image_size // patch_size) / downsample_ratio) + + global_views_tokens = h * (w + 1) + local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1) + return global_views_tokens + local_views_tokens + 1 + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + candidate_resolutions = hf_config.candidate_resolutions + height, width = max(candidate_resolutions, + key=lambda x: self.get_num_image_tokens( + image_width=x[1], image_height=x[0])) + return ImageSize(width=width, height=height) + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + num_images = mm_counts.get("image", 0) + max_image_size = self.get_image_size_with_most_features() + max_image_tokens = self.get_num_image_tokens( + image_height=max_image_size.height, + image_width=max_image_size.width, + cropping=num_images <= 2) + + return {"image": max_image_tokens} + + +class DeepseekVL2DummyInputsBuilder( + BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, 
int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + hf_processor = self.info.get_hf_processor() + image_token: str = hf_processor.image_token + + max_image_size = self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size.width, + height=max_image_size.height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class DeepseekVL2MultiModalProcessor( + BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(prompt=prompt, **mm_data), + mm_kwargs, + ) + target_dtype = self.info.ctx.model_config.dtype + pixel_values = processed_outputs.pop("pixel_values").to( + target_dtype) + # split pixel values into patches corresponding to each image + images_spatial_crop = processed_outputs["images_spatial_crop"] + patches_per_image = [ + x.prod().item() + 1 for x in images_spatial_crop + ] + pixel_values = pixel_values.split(patches_per_image) + processed_outputs["pixel_values"] = pixel_values + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + images_spatial_crop=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_token_id = hf_processor.image_token_id + assert isinstance(image_token_id, int) + + def get_replacement_deepseek_vl2(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + cropping=len(images) <= 2, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_deepseek_vl2, + ) + ] + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs, bool]: + # The processor logic is different for len(images) <= 2 vs > 2 + # Since the processing cache assumes that the processor output is + # invariant of how many images are passed per prompt, we only + # perform caching for the most common case + if mm_data_items.get_count("image", strict=False) > 2: + # This code path corresponds to the cache being disabled + return self._apply_hf_processor_main( + prompt=prompt, + mm_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + enable_hf_prompt_update=True, + ) + + return 
super()._cached_apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + DeepseekVL2MultiModalProcessor, + info=DeepseekVL2ProcessingInfo, + dummy_inputs=DeepseekVL2DummyInputsBuilder) +class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "language.": "language_model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: DeepseekVLV2Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_config = config.vision_config + self.projector_config = config.projector_config + self.text_config = config.text_config + + model_config = vllm_config.model_config + tokenizer = cached_tokenizer_from_config(model_config) + self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + + self.vision = self._init_vision_module(self.vision_config, + quant_config, + maybe_prefix(prefix, "vision")) + + self.projector = MlpProjector(self.projector_config) + self.tile_tag = config.tile_tag + self.global_view_pos = config.global_view_pos + + # special token for image token sequence format + embed_std = 1 / torch.sqrt( + torch.tensor(self.projector_config.n_embed, dtype=torch.float32)) + if self.tile_tag == "2D": + # <|view_separator|>, <|\n|> + self.image_newline = nn.Parameter( + torch.randn(self.projector_config.n_embed) * embed_std) + # This is a typo in original implementation + self.view_seperator = nn.Parameter( + torch.randn(self.projector_config.n_embed) * embed_std) + else: + raise ValueError( + f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" + ) + + if self.text_config.topk_method == "noaux_tc": + architectures = ["DeepseekV3ForCausalLM"] + elif not self.text_config.use_mla: + architectures = ["DeepseekForCausalLM"] + else: + architectures = ["DeepseekV2ForCausalLM"] + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.text_config, + prefix=maybe_prefix(prefix, "language"), + architectures=architectures, + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _init_vision_module( + self, + vision_config: VisionEncoderConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> nn.Module: + # TODO: refactor vision model through timm wrapper from transformers + try: + import timm + except ImportError: + raise ImportError("Please install timm") from ImportError + + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + model = model.to(dtype=torch.get_default_dtype()) + return model + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise 
ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_images_spatial_crop( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + expected_dims = 2 + + def _validate_shape(d: torch.Tensor): + actual_dims = d.size(-1) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + images_spatial_crop = kwargs.pop("images_spatial_crop", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(images_spatial_crop, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(images_spatial_crop)}") + + return DeepseekVL2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(flatten_bn(pixel_values)), + images_spatial_crop=self._validate_images_spatial_crop( + flatten_bn(images_spatial_crop, concat=True))) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return DeepseekVL2VImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _pixel_values_to_embedding( + self, + pixel_values: NestedTensors, + images_spatial_crop: torch.Tensor, + ) -> NestedTensors: + # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width] + total_tiles = [x for x in pixel_values] + + # [batch_all_tiles, 3, height, width] + total_tiles = torch.cat(total_tiles, dim=0) + + # [batch_all_tiles, vit_seq_len, c] + images_feature = self.vision.forward_features(total_tiles) + + # [batch_all_tiles, hw, D] + images_embeds = self.projector(images_feature) + + _, hw, n_dim = images_embeds.shape + h = w = int(hw**0.5) + + # fill image token based on self.tile_tag & self.global_view_pos + tile_index = 0 + vision_embeddings = [] + for jdx in range(images_spatial_crop.size(0)): + # extra global & local features + num_width_tiles, num_height_tiles = images_spatial_crop[jdx] + if num_width_tiles == 0 or num_height_tiles == 0: + break + num_tiles_in_image = num_width_tiles * num_height_tiles + + # [hw, D] + global_features = images_embeds[tile_index] + + # [num_height_tiles * num_width_tiles, hw, D] + local_features = images_embeds[tile_index + 1:tile_index + 1 + + num_tiles_in_image] + tile_index += num_tiles_in_image + 1 + + # format global and local features + # ----------------- global view add newline ----------------- + # [hw, D] -> [h, w, D] + global_features = global_features.view(h, w, n_dim) + + # [D] -> [h, 1, D] + new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h) + + # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D] + global_features = torch.cat([global_features, new_lines_in_global], + dim=1) + + # 
[h, w + 1, D] -> [h * (w + 1), D] + global_features = global_features.view(-1, n_dim) + + # ----------------- local view add newline ----------------- + # [num_height_tiles * num_width_tiles, h * w, D] -> + # [num_height_tiles * h, num_width_tiles * w, D] + local_features = rearrange(local_features, + "(th tw) (h w) d -> (th h) (tw w) d", + th=num_height_tiles, + tw=num_width_tiles, + h=h, + w=w) + + # [D] -> [num_height_tiles * h, 1, D] + new_lines_in_local = repeat(self.image_newline, + "d -> (th h) 1 d", + th=num_height_tiles, + h=h) + + # [num_height_tiles * h, num_width_tiles * w + 1, D] + local_features = torch.cat([local_features, new_lines_in_local], + dim=1) + + # [num_height_tiles * h, num_width_tiles * w + 1, D] + # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D] + local_features = local_features.view(-1, n_dim) + + # merge global and local tiles + if self.global_view_pos == "head": + global_local_features = torch.cat([ + global_features, + self.view_seperator[None, :], + local_features, + ]) + else: + global_local_features = torch.cat([ + local_features, + self.view_seperator[None, :], + global_features, + ]) + + vision_embeddings.append(global_local_features) + return vision_embeddings + + def _process_image_input( + self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + image_data = image_input["data"] + if is_list_of(image_data, torch.Tensor): + # it's already a list of tensors + return image_data + if len(image_data.shape) == 3: + # 3D tensor + return list(torch.unbind(image_data, dim=0)) + raise ValueError( + "We expect batched 2D tensors; " + "this can be either a list of 2D tensors or a single 3D tensor." + ) + + pixel_values = image_input["data"] + images_spatial_crop = image_input["images_spatial_crop"] + + return self._pixel_values_to_embedding( + pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.image_token_id) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object): + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + 
def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + return autoloaded_weights diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py new file mode 100644 index 0000000..9183ab4 --- /dev/null +++ b/vllm/model_executor/models/eagle.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import maybe_prefix + +logger = init_logger(__name__) + + +class DummyInputLayerNorm(nn.Module): + + def __init__(self, weight=None, bias=None): + super().__init__() + self.weight = nn.Parameter(weight) if weight is not None else None + self.bias = nn.Parameter(bias) if bias is not None else None + + def forward(self, x): + return x + + +class DummyOutputNorm(nn.Module): + + def forward(self, x, residual): + if residual is None: + return x + else: + return x + residual, None + + +class EAGLE(nn.Module, SupportsPP): + """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 + Reference implementation: https://github.com/SafeAILab/EAGLE + + Differences from reference implementation: + 1. In reference, LlamaDecoderLayer implementation doesn't have + input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427). + Following this approach, our implementation also disables + the input_layernorm for the first decoder layer. + 2. We allow any decoder layer to be used in EAGLE whereas in reference + decoder layer is fixed to be LlamaDecoderLayer. + 3. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute. + 4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP + module with regards to the use of additional RMS norms. The original + EAGLE architecture 1) skips the pre-attention norm in its first + transformer block, and 2) skips the final output norm, both of which we + found to be suboptimal. 
We also add the support for separate norms + applying to both the token embedding and hidden states before projection + as in DeepSeek MTP, which we found to improve performance as well. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + + architectures = getattr(self.config.model, "architectures", []) + model_cls, _ = ModelRegistry.resolve_model_cls(architectures) + + self.model = model_cls(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.fc = nn.Linear(config.model.hidden_size * 2, + config.model.hidden_size, + bias=getattr(self.config, "eagle_fc_bias", False)) + + # Modify layer normalization and residual connections as suggested + # in the EAGLE framework: https://github.com/SafeAILab/EAGLE + # While weights and biases are generally not needed, + # they are retained here to support certain unit tests + # (e.g., spec_decode/e2e/test_eagle_correctness.py). + if not hasattr(self.config.model, + "skip_prenorm") or self.config.model.skip_prenorm: + self.model.model.layers[0].input_layernorm = DummyInputLayerNorm( + weight=self.model.model.layers[0].input_layernorm.weight) + + if not hasattr( + self.config.model, + "skip_output_norm") or self.config.model.skip_output_norm: + self.model.model.norm = DummyOutputNorm() + + self.add_para_norm = False + if hasattr(self.config.model, + "add_para_norm") and self.config.model.add_para_norm: + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.add_para_norm = True + + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. 
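+        # Illustrative sketch only (names below are placeholders, not part of
+        # this module): such a token_map could be built offline from
+        # target-domain token counts, e.g.
+        #   freq = torch.bincount(target_token_ids, minlength=config.vocab_size)
+        #   token_map = torch.topk(freq, k=config.truncated_vocab_size).indices
+        # and then saved in the draft checkpoint under the key "token_map".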
+ self.token_map = None + + @property + def sampler(self): + return self.model.sampler + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + + if self.add_para_norm: + inputs_embeds = torch.cat([ + self.enorm(inputs_embeds), + self.hnorm(previous_hidden_states) + ], + dim=-1) + else: + inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states], + dim=-1) + + inputs_embeds = self.fc(inputs_embeds) + + inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + + hidden_states = self.model.model( + input_ids=None, + inputs_embeds=inputs_embeds, + positions=positions, + intermediate_tensors=intermediate_tensors, + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + if self.token_map is not None: + _logits = logits + logits = -torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype) + + logits[..., self.token_map] = _logits + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B + # due to missing lm_head weights and its config being that of a + # Llama model.
Here's a compatible version with the same weights: + # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm + # Also, here's an example script for converting trained EAGLE + # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d + model_weights = {} + for name, loaded_weight in weights: + if name == "token_map": + if self.config.truncated_vocab_size < self.config.vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name.startswith("fc.weight"): + weight_loader = getattr(self.fc.weight, "weight_loader", + default_weight_loader) + weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("fc.bias"): + if self.fc.bias is not None: + weight_loader = getattr(self.fc.bias, "weight_loader", + default_weight_loader) + weight_loader(self.fc.bias, loaded_weight) + else: + logger.warning_once("Found bias in the loaded weights but " + "the model config doesn't have bias.") + elif name.startswith("enorm.weight"): + weight_loader = getattr(self.enorm.weight, "weight_loader", + default_weight_loader) + weight_loader(self.enorm.weight, loaded_weight) + elif name.startswith("hnorm.weight"): + weight_loader = getattr(self.hnorm.weight, "weight_loader", + default_weight_loader) + weight_loader(self.hnorm.weight, loaded_weight) + elif name.startswith("model.lm_head.") or name.startswith( + "model.model."): + model_weights[name.split("model.", 1)[-1]] = loaded_weight + elif name.startswith("lm_head.") or name.startswith("model."): + model_weights[name] = loaded_weight + else: + model_weights[f"model.{name}"] = loaded_weight + + if "lm_head.weight" in model_weights: + lm_head_weight = model_weights.pop("lm_head.weight") + + if self.token_map is not None and\ + lm_head_weight.shape[0] > self.token_map.shape[0]: + + lm_head_weight = lm_head_weight[self.token_map] + + else: + # NOTE(Shangming): initialize the placeholder for lm_head weight. + lm_head_weight = torch.zeros( + self.lm_head.org_vocab_size, + self.lm_head.embedding_dim, + dtype=self.config.torch_dtype, + ) + + weight_loader = getattr(self.lm_head.weight, "weight_loader", + default_weight_loader) + weight_loader(self.lm_head.weight, lm_head_weight) + + self.model.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py new file mode 100644 index 0000000..553c524 --- /dev/null +++ b/vllm/model_executor/models/exaone.py @@ -0,0 +1,559 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py +# Copyright 2024 The LG U+ CTO AI Tech Lab. +# Copyright 2021 The LG AI Research EXAONE Lab +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.exaone import ExaoneConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class ExaoneGatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
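+            # For example, total_num_kv_heads = 8 with tp_size = 4 leaves
+            # 2 KV heads on each tensor-parallel rank.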
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneBlockAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.attention = ExaoneAttention( + config=config, + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=bias, + cache_config=cache_config, + prefix=f"{prefix}.attention", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.attention( + positions=positions, + hidden_states=hidden_states, + ) + + +class ExaoneDecoderLayer(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + 
max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.attn = ExaoneBlockAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class ExaoneModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.wte = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.wte = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.wte = PPMissingLayer() + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: ExaoneDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.h", + ) + if get_pp_group().is_last_rank: + self.ln_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + else: + self.ln_f = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + 
hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".c_fc_0", 0), + (".gate_up_proj", ".c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "c_fc_0", + "c_fc_1", + ], + } + + # LoRA specific attributes + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.transformer = ExaoneModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc.
+ skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py new file mode 100644 index 0000000..310aca9 --- /dev/null +++ b/vllm/model_executor/models/fairseq2_llama.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The vLLM team. +# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llama model for fairseq2 weights.""" + +from typing import Iterable, Set, Tuple + +import torch +from torch.nn import Parameter + +from vllm.config import VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import set_weight_attrs +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import AutoWeightsLoader, WeightsMapper + + +class Fairseq2LlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + # For the model loader to read only the relevant checkpoint files + self.allow_patterns_overrides = [ + # either the full checkpoint + "model.pt", + # or the tp-sharded checkpoint of the current rank + f"model.{self.tp_rank}.pt", + ] + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + # fairseq2's serialization adds a wrapper to usual .pt state_dict's: + # { "model_key": my_model_name, "my_model_name": state_dict } + # which we first need to unpack + weights_wrapped = dict(weights) + weights = weights_wrapped[ + weights_wrapped["model_key"]].items() # type: ignore + + # remap keys + fs2_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder_frontend.embed.": "model.embed_tokens.", + "decoder.": "model.", + "final_proj.": "lm_head.", + }, + orig_to_new_substr={ + ".self_attn_layer_norm.": ".input_layernorm.", + ".ffn_layer_norm.": ".post_attention_layernorm.", + ".self_attn.output_proj.": ".self_attn.o_proj.", + ".ffn.gate_proj.": ".mlp.gate_proj.", + ".ffn.inner_proj.": ".mlp.up_proj.", + ".ffn.output_proj.": ".mlp.down_proj.", + ".layer_norm.": ".norm.", + }, + ) + weights = fs2_to_vllm_mapper.apply(weights) + + params = dict(self.named_parameters()) + + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights( + (self.reshape_fairseq2_weights(name, loaded_weight, params) + for name, loaded_weight in weights)) + + def flag_sharded_weights(self, params: dict[str, Parameter]): + """Sets the `is_sharded_weight` flag to True for all sharded weights""" + for name, param in params.items(): + modules = name.split(".") + if "norm" in name and len(param.size()) < 2: + # layer norms are not 
sharded + continue + elif any(emb in modules for emb in ["embed_tokens", "lm_head"]): + # for now we repeat embedding layers for compatibility + continue + else: + # all other layers are sharded + set_weight_attrs(param, {"is_sharded_weight": True}) + + def reshape_fairseq2_weights( + self, + name: str, + loaded_weight: torch.Tensor, + params: dict[str, Parameter], + ) -> Tuple[str, torch.Tensor]: + """Reshape fairseq2's weights.""" + + def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: + attn_in = self.config.head_dim * n_heads + # check for a sharded weight on dim 0 + if attn_in // self.tp_size == w.size()[0]: + attn_in //= self.tp_size + n_heads //= self.tp_size + attn_out = self.config.hidden_size + return (w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, + 2).reshape(attn_in, attn_out)) + + modules = name.split(".") + + # rotary embeds should be sliced + if "k_proj" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + + elif "q_proj" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + # We make the loaded weights compatible with both + # full checkpoints and tp sharded checkpoints. + # Embeddings are repeated to fit the vocab size. + # Other weights are flagged for the weight_loader calls. + if any(emb in modules for emb in ["embed_tokens", "lm_head"]): + # Embeddings are sharded on dim 0 + dim = 0 + # In fairseq2, vocab size has to be divisible by tp_size + # so we don't worry about padding + if self.tp_size > 1 and loaded_weight.shape[ + dim] < self.config.vocab_size: + assert loaded_weight.shape[ + dim] * self.tp_size == self.config.vocab_size, \ + "vocab_size should be divisible by tp_size." + repeats = [1] * len(loaded_weight.size()) + repeats[dim] = self.tp_size + # repeat to match vocab size and to be easily 'narrow'able + loaded_weight = loaded_weight.repeat(repeats) + set_weight_attrs(params[name], {"is_sharded_weight": False}) + # if embeddings are sharded, the rest is too + if "embed_tokens" in modules: + self.flag_sharded_weights(params) + + return name, loaded_weight diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py new file mode 100644 index 0000000..0e67b1e --- /dev/null +++ b/vllm/model_executor/models/falcon.py @@ -0,0 +1,518 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
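The `permute` helper in `reshape_fairseq2_weights` above de-interleaves fairseq2's pairwise rotary layout into the half-split per-head layout that vLLM's rotary embedding expects. Below is a minimal standalone sketch of that transform (not part of this diff); the helper name `permute`, the toy sizes (one head, `head_dim=4`, `hidden_size=4`), and the omission of the tensor-parallel branch are all assumptions made for illustration.

```python
import torch


def permute(w: torch.Tensor, n_heads: int, head_dim: int,
            hidden_size: int) -> torch.Tensor:
    # Same reshape as in reshape_fairseq2_weights, minus the tp_size branch:
    # rows stored as interleaved rotary pairs (x0, y0, x1, y1, ...) become
    # half-split per head (x0, x1, ..., y0, y1, ...).
    attn_in = head_dim * n_heads
    return (w.view(n_heads, attn_in // n_heads // 2, 2,
                   hidden_size).transpose(1, 2).reshape(attn_in, hidden_size))


w = torch.arange(16.).reshape(4, 4)  # rows 0, 1, 2, 3
print(permute(w, n_heads=1, head_dim=4, hidden_size=4)[:, 0])
# tensor([ 0.,  8.,  4., 12.])  -> row order 0, 2, 1, 3 (pairs de-interleaved)
```

In the loader above, the same reshape is applied to the `q_proj` and `k_proj` weights, with `attn_in` and `n_heads` divided by `tp_size` when the checkpoint is already sharded on dim 0.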
+"""PyTorch Falcon model.""" + +import math +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from transformers import FalconConfig as HF_FalconConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import RWConfig + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +FalconConfig = Union[HF_FalconConfig, RWConfig] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(1, + 1 + 2 * num_remaining_heads, + 2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + + return slopes + + +class FalconAttention(nn.Module): + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + + if self.new_decoder_architecture: + self.total_num_kv_heads = config.num_kv_heads + elif self.multi_query: + self.total_num_kv_heads = 1 + else: + self.total_num_kv_heads = self.total_num_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + reduce_results=self.reduce_row_parallel_results) + + self.use_rotary = config.rotary + self.use_alibi = config.alibi + assert not (self.use_rotary and self.use_alibi), ( + "Rotary and alibi are mutually exclusive.") + + if self.use_rotary: + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + prefix=f"{prefix}.attn") + elif self.use_alibi: + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * + self.inv_norm_factor) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn") + else: + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, bias = self.query_key_value(hidden_states) + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_rotary: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + attn_output, bias = self.dense(attn_output) + return attn_output, bias + + +class FalconMLP(nn.Module): + + def __init__( + self, + config: FalconConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + + self.dense_h_to_4h = ColumnParallelLinear(hidden_size, + 4 * hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config) + self.act = get_act_fn("gelu") + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + bias=config.bias, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results, + quant_config=quant_config) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE(zhuohan): Following huggingface, we do not fuse bias 
add here. + x, bias = self.dense_h_to_4h(x) + if bias is not None: + x += bias + x = self.act(x) + x, bias = self.dense_4h_to_h(x) + return x, bias + + +class FalconDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.self_attention = FalconAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") + self.mlp = FalconMLP(config, quant_config) + self.config = config + + if (not hasattr(config, "num_ln_in_parallel_attn")): + config.num_ln_in_parallel_attn = None + + if (config.num_ln_in_parallel_attn is None + and config.new_decoder_architecture): + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + if self.config.num_ln_in_parallel_attn == 2: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention( + positions=positions, + hidden_states=attention_layernorm_out, + ) + if self.reduce_row_parallel_results and attention_bias is not None: + attention_output += attention_bias + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual += attention_output + mlp_layernorm_out = self.post_attention_layernorm(residual) + + if (self.config.new_decoder_architecture and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1): + mlp_layernorm_out = attention_layernorm_out + + # MLP. + mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) + if self.reduce_row_parallel_results and mlp_bias is not None: + mlp_output += mlp_bias + + if not self.reduce_row_parallel_results: + # When MLP and Attention layers are parallel, we can use + # only one all-reduce operator to reduce the results from + # both MLP and Attention layers. 
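+            # In this configuration both RowParallelLinear layers were built
+            # with reduce_results=False, so the single all-reduce below
+            # replaces the two per-projection all-reduces.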
+ mlp_output += attention_output + mlp_output = tensor_model_parallel_all_reduce(mlp_output) + if attention_bias is not None: + mlp_output += attention_bias + if mlp_bias is not None: + mlp_output += mlp_bias + + output = mlp_output + residual + return output + + +@support_torch_compile +class FalconModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_alibi = config.alibi + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + + # Transformer blocks + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: FalconDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + + # Final Layer Norm + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + total_num_heads = self.config.num_attention_heads + if self.config.new_decoder_architecture: + total_num_kv_heads = self.config.num_kv_heads + elif self.config.multi_query: + total_num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + if "query_key_value" in name: + output_dim = getattr(param, "output_dim", None) + loaded_weight_shape = loaded_weight.shape + if output_dim is not None: + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, + -1) + loaded_weight_shape[output_dim + 1:]) + wq = loaded_weight.narrow( + output_dim + 1, 0, + num_query_heads_per_kv_head).reshape( + *loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wk = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wv = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head + 1, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class FalconForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.transformer = FalconModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + # only Falcon-11B doesn't share lm_head weight with word embeddings + # and previous Falcon model doesn't have tie_word_embeddings config + # so we set tie_word_embeddings to True by default + self.tie_word_embeddings = (config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True) + if self.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git 
a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py new file mode 100644 index 0000000..02535cc --- /dev/null +++ b/vllm/model_executor/models/florence2.py @@ -0,0 +1,1116 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections import OrderedDict +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, + BartParallelLMHead, + BartScaledWordEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, + SupportsV0Only) +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings + + +class Florence2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channel, height, width)""" + + +# ViT implementation are all copied from +# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py +class LearnedAbsolutePositionEmbedding2D(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256, num_pos=50): + super().__init__() + self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) + self.column_embeddings = nn.Embedding( + num_pos, embedding_dim - (embedding_dim // 2)) + + def forward(self, pixel_values): + """ + pixel_values: (batch_size, height, width, num_channels) + returns: (batch_size, height, width, embedding_dim * 2) + """ + if len(pixel_values.shape) != 4: + raise ValueError('pixel_values must be a 4D tensor') + height, width = pixel_values.shape[1:3] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + # (height, width, embedding_dim * 2) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(height, 1, 1), + y_emb.unsqueeze(1).repeat(1, width, 1) + ], + dim=-1) + # (embedding_dim * 2, height, width) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + # (batch_size, embedding_dim * 2, height, width) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + # (batch_size, height, width, embedding_dim * 2) + pos = pos.permute(0, 2, 3, 1) + return pos + + +class PositionalEmbeddingCosine1D(nn.Module): + """ + This class implements a very simple positional encoding. 
It follows closely + the encoder from the link below: + https://pytorch.org/tutorials/beginner/translation_transformer.html + Args: + embed_dim: The dimension of the embeddings. + dropout_prob: The dropout probability. + max_seq_len: The maximum length to precompute the positional encodings. + """ + + def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: + super().__init__() + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + # Generate the sinusoidal arrays. + factor = math.log(10000) + denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / + self.embed_dim) + # Matrix where rows correspond to a positional embedding as a function + # of the position index (i.e., the row index). + frequencies = \ + torch.arange(0, self.max_seq_len) \ + .reshape(self.max_seq_len, 1) * denominator + pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) + # Populate uneven entries. + pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) + pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) + # Save the positional embeddings in a constant buffer. + # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) + self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed, + requires_grad=False) + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + """ + Args: + seq_embeds: The sequence embeddings in order. Allowed size: + 1. [T, D], where T is the length of the sequence, and D is the + frame embedding dimension. + 2. [B, T, D], where B is the batch size and T and D are the + same as above. + Returns a tensor of with the same dimensions as the input: i.e., + [1, T, D] or [T, D]. + """ + shape_len = len(seq_embeds.shape) + assert 2 <= shape_len <= 3 + len_seq = seq_embeds.size(-2) + assert len_seq <= self.max_seq_len + pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] + # Adapt pre-computed positional embeddings to the input. 
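+        # A [B, T, D] input gets [1, T, D] embeddings that broadcast over the
+        # batch dimension; a [T, D] input keeps the [T, D] shape.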
+ if shape_len == 3: + pos_embeds = pos_embeds.view( + (1, pos_embeds.size(0), pos_embeds.size(1))) + return pos_embeds + + +class MySequential(nn.Sequential): + + def forward(self, *inputs): + for module in self._modules.values(): + if isinstance(inputs, tuple): + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + + +class PreNorm(nn.Module): + + def __init__(self, norm, fn): + super().__init__() + self.norm = norm + self.fn = fn + + def forward(self, x, *args, **kwargs): + shortcut = x + if self.norm is not None: + x, size = self.fn(self.norm(x), *args, **kwargs) + else: + x, size = self.fn(x, *args, **kwargs) + + x = shortcut + x + + return x, size + + +class Mlp(nn.Module): + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.net = nn.Sequential( + OrderedDict([("fc1", nn.Linear(in_features, hidden_features)), + ("act", act_layer()), + ("fc2", nn.Linear(hidden_features, out_features))])) + + def forward(self, x, size): + return self.net(x), size + + +class DepthWiseConv2d(nn.Module): + + def __init__( + self, + dim_in, + kernel_size, + padding, + stride, + bias=True, + ): + super().__init__() + self.dw = nn.Conv2d(dim_in, + dim_in, + kernel_size=kernel_size, + padding=padding, + groups=dim_in, + stride=stride, + bias=bias) + + def forward(self, x, size): + B, N, C = x.shape + H, W = size + assert N == H * W + + x = self.dw(x.transpose(1, 2).view(B, C, H, W)) + size = (x.size(-2), x.size(-1)) + x = x.flatten(2).transpose(1, 2) + return x, size + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None, + pre_norm=True): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding) + + dim_norm = in_chans if pre_norm else embed_dim + self.norm = norm_layer(dim_norm) if norm_layer else None + + self.pre_norm = pre_norm + + def forward(self, x, size): + H, W = size + if len(x.size()) == 3: + if self.norm and self.pre_norm: + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) + + x = self.proj(x) + + _, _, H, W = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm and not self.pre_norm: + x = self.norm(x) + + return x, (H, W) + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, groups=8, qkv_bias=True): + super().__init__() + + self.groups = groups + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, size): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.groups, + C // self.groups).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * (float(N)**-0.5) + attention = q.transpose(-1, -2) @ k + attention = attention.softmax(dim=-1) + x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x, size + + +class ChannelBlock(nn.Module): + + def __init__(self, + dim, + groups, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.channel_attn = PreNorm( + 
norm_layer(dim), + ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.channel_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + + return x, size + + +def window_partition(x, window_size: int): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): + B = batch_size + + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + + def __init__(self, dim, num_heads, window_size, qkv_bias=True): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = float(head_dim)**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, size): + + H, W = size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition(x, self.window_size) + x = x.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # attn_windows = self.attn(x_windows) + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + + # merge windows + x = x.view(-1, self.window_size, self.window_size, C) + x = window_reverse(x, B, self.window_size, Hp, Wp) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + return x, size + + +class SpatialBlock(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.window_attn = PreNorm( + norm_layer(dim), + WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.window_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + return x, size + + +class DaViT(nn.Module): + + def __init__( + self, + in_chans=3, + 
num_classes=1000, + depths=(1, 1, 3, 1), + patch_size=(7, 2, 2, 2), + patch_stride=(4, 2, 2, 2), + patch_padding=(3, 0, 0, 0), + patch_prenorm=(False, False, False, False), + embed_dims=(64, 128, 192, 256), + num_heads=(3, 6, 12, 24), + num_groups=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + enable_checkpoint=False, + conv_at_attn=True, + conv_at_ffn=True, + ): + super().__init__() + + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_groups = num_groups + self.num_stages = len(self.embed_dims) + self.enable_checkpoint = enable_checkpoint + assert self.num_stages == len(self.num_heads) == len(self.num_groups) + + num_stages = len(embed_dims) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths) * 2) + ] + + depth_offset = 0 + convs = [] + blocks = [] + for i in range(num_stages): + conv_embed = ConvEmbed( + patch_size=patch_size[i], + stride=patch_stride[i], + padding=patch_padding[i], + in_chans=in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + norm_layer=norm_layer, + pre_norm=patch_prenorm[i]) + convs.append(conv_embed) + + block = MySequential(*[ + MySequential( + OrderedDict([('spatial_block', + SpatialBlock( + embed_dims[i], + num_heads[i], + window_size, + drop_path_rate=dpr[depth_offset + j * 2], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + )), + ('channel_block', + ChannelBlock( + embed_dims[i], + num_groups[i], + drop_path_rate=dpr[depth_offset + j * 2 + + 1], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ))])) for j in range(depths[i]) + ]) + blocks.append(block) + depth_offset += depths[i] * 2 + + self.convs = nn.ModuleList(convs) + self.blocks = nn.ModuleList(blocks) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + + @property + def dim_out(self): + return self.embed_dims[-1] + + def forward_features_unpool(self, x): + """ + forward until avg pooling + Args: + x (_type_): input image tensor + """ + input_size = (x.size(2), x.size(3)) + for conv, block in zip(self.convs, self.blocks): + x, input_size = conv(x, input_size) + x, input_size = block(x, input_size) + return x + + def forward_features(self, x): + x = self.forward_features_unpool(x) + + # (batch_size, num_tokens, token_dim) + x = self.avgpool(x.transpose(1, 2)) + # (batch_size, 1, num_tokens) + x = torch.flatten(x, 1) + x = self.norms(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + @classmethod + def from_config(cls, config): + return cls( + depths=config.depths, + embed_dims=config.dim_embed, + num_heads=config.num_heads, + num_groups=config.num_groups, + patch_size=config.patch_size, + patch_stride=config.patch_stride, + patch_padding=config.patch_padding, + patch_prenorm=config.patch_prenorm, + drop_path_rate=config.drop_path_rate, + window_size=config.window_size, + ) + + +# Language backbone and processor implementation +class Florence2LanguageModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.vocab_size = config.vocab_size + + self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) + self.encoder = BartEncoder(config, + 
cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = BartDecoder(config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + if self.config.tie_word_embeddings: + self.encoder.embed_tokens.weight = self.shared.weight + self.decoder.embed_tokens.weight = self.shared.weight + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if inputs_embeds is not None or encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions, + inputs_embeds=inputs_embeds) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states) + + return decoder_outputs + + +class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = Florence2LanguageModel(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. 
+ encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + return self.model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.encoder.embed_tokens(input_ids) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> SamplerOutput: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + if self.config.tie_word_embeddings and "embed_tokens" in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Florence2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_max_image_tokens(self) -> int: + processor_config = self.ctx.get_hf_image_processor_config() + return processor_config["image_seq_length"] + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + +class Florence2DummyInputsBuilder( + BaseDummyInputsBuilder[Florence2ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width = target_height = self.info.get_hf_config().projection_dim + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class Florence2MultiModalProcessor( + EncDecMultiModalProcessor[Florence2ProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return [self.info.get_hf_config().eos_token_id] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: 
Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs) + else: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + prompt = hf_processor._construct_prompts([prompt])[0] + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + pad_token_id = hf_config.pad_token_id + num_image_tokens = self.info.get_max_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Florence2MultiModalProcessor, + info=Florence2ProcessingInfo, + dummy_inputs=Florence2DummyInputsBuilder) +class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.vision_config + self.processor_config = processor_config + assert config.vision_config.model_type == 'davit', ( + 'only DaViT is supported for now') + self.vision_tower = DaViT.from_config(config=config.vision_config) + self._build_image_projection_layers(config) + self.language_model = Florence2LanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=f"{prefix}.language_model", + ) + self.pad_token_id = config.pad_token_id + + def _build_image_projection_layers(self, config: PretrainedConfig): + image_dim_out = config.vision_config.dim_embed[-1] + dim_projection = config.vision_config.projection_dim + self.image_projection = nn.Parameter( + torch.empty(image_dim_out, dim_projection)) + self.image_proj_norm = nn.LayerNorm(dim_projection) + image_pos_embed_config = config.vision_config.image_pos_embed + if image_pos_embed_config['type'] == 'learned_abs_2d': + self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( + embedding_dim=image_dim_out, + num_pos=image_pos_embed_config['max_pos_embeddings']) + else: + raise NotImplementedError("Florence2 only supports learned_abs_2d " + "as image position embedding.") + + self.image_feature_source = config.vision_config.image_feature_source + + # temporal embedding + visual_temporal_embedding_config = ( + self.vision_config.visual_temporal_embedding) + if visual_temporal_embedding_config['type'] == 'COSINE': + self.visual_temporal_embed = PositionalEmbeddingCosine1D( + embed_dim=image_dim_out, + max_seq_len=visual_temporal_embedding_config[ + 'max_temporal_embeddings']) + else: + raise NotImplementedError( + 'Florence2 only supports COSINE as temporal embedding.') + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + return get_sampler() + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> 
Union[torch.Tensor, List[torch.Tensor]]: + + size = self.processor_config["size"] + h, w = size["height"], size["width"] + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = tuple(*map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + return Florence2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor: + dtype = next(self.vision_tower.parameters()).dtype + pixel_values = pixel_values.to(dtype) + + batch_size, T = pixel_values.size(0), 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, ( + 'only support square feature maps for now') + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, + x.shape[-1]) + visual_temporal_embed.view( + 1, T, 1, x.shape[-1]) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, + x.shape[-1]).mean(dim=1) + x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict['last_frame'] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError('invalid image feature source: {}'.format( + _image_feature_source)) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + def _process_image_input( + self, image_input: Florence2ImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + return self._encode_image(pixel_values) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: 
torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.pad_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + if encoder_input_ids.numel() > 0 or vision_embeddings is not None: + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + else: + inputs_embeds = None + + hidden_states = self.language_model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py new file mode 100644 index 0000000..a807b04 --- /dev/null +++ b/vllm/model_executor/models/fuyu.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py +# Copyright 2023 The vLLM team. +# Copyright 2023 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
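Fuyu lays each image out as a grid of placeholder tokens: `ncols` image-patch tokens plus one newline token per grid row, so a prompt grows by `(ncols + 1) * nrows` tokens. The sketch below (not part of this diff) mirrors the arithmetic of `FuyuProcessingInfo.get_image_feature_grid_size` defined later in this file; the function name `feature_grid` is made up for the example, and the 1920x1080 target canvas and 30x30 patch size are assumptions based on the stock Fuyu-8B image processor, whereas the real code reads them from the processor config.

```python
import math


def feature_grid(image_width: int, image_height: int,
                 target_w: int = 1920, target_h: int = 1080,
                 patch_w: int = 30, patch_h: int = 30) -> tuple[int, int]:
    # Oversized images are scaled down to fit the target canvas, keeping
    # the aspect ratio.
    if not (image_width <= target_w and image_height <= target_h):
        scale = min(target_h / image_height, target_w / image_width)
        image_height = int(image_height * scale)
        image_width = int(image_width * scale)
    ncols = math.ceil(image_width / patch_w)
    nrows = math.ceil(image_height / patch_h)
    return ncols, nrows


ncols, nrows = feature_grid(640, 480)  # -> (22, 16)
# One newline token is appended per row of image tokens.
print(ncols, nrows, (ncols + 1) * nrows)  # 22 16 368
```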
+""" PyTorch Fuyu model.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, + FuyuProcessor) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.models.persimmon import PersimmonForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features + +# Cannot find the following 2 numbers from hf config. +_IMAGE_TOKEN_ID = 71011 +_NEWLINE_TOKEN_ID = 71019 + + +class FuyuImagePatchInputs(TypedDict): + type: Literal["image_patches"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + patches_per_image: list[int] + """ + The number of total patches for each image in the batch. + + This is used to split the embeddings which has the first two dimensions + flattened just like `flat_data`. + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. 
+ + Shape: `(batch_size * num_images, num_embeds)` + """ + + +class FuyuProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(FuyuConfig) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(FuyuProcessor, **kwargs) + + def get_image_processor(self) -> FuyuImageProcessor: + return self.get_hf_processor().image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + target_width, target_height = self.get_image_size_with_most_features() + + max_ncols, max_nrows = self.get_image_feature_grid_size( + image_width=target_width, + image_height=target_height, + ) + max_image_tokens = (max_ncols + 1) * max_nrows + + return {"image": max_image_tokens} + + def get_image_feature_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + image_processor = self.get_image_processor() + target_width = image_processor.size["width"] + target_height = image_processor.size["height"] + patch_width = image_processor.patch_size["width"] + patch_height = image_processor.patch_size["height"] + + if not (image_width <= target_width and image_height <= target_height): + height_scale_factor = target_height / image_height + width_scale_factor = target_width / image_width + optimal_scale_factor = min(height_scale_factor, width_scale_factor) + + image_height = int(image_height * optimal_scale_factor) + image_width = int(image_width * optimal_scale_factor) + + ncols = math.ceil(image_width / patch_width) + nrows = math.ceil(image_height / patch_height) + return ncols, nrows + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + return ImageSize(width=image_processor.size["width"], + height=image_processor.size["height"]) + + +class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # Avoid warning from HF logger for text-only input + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + image_patches = processed_outputs.get("image_patches") + if image_patches is not None: + images = mm_data["images"] + assert isinstance(images, list) + + # Original output: (1, num_images, Pn, Px * Py * C) + # New output: (num_images, Pn, Px * Py * C) + assert (isinstance(image_patches, list) + and len(image_patches) == 1) + assert (isinstance(image_patches[0], torch.Tensor) + and len(image_patches[0]) == len(images)) + + processed_outputs["image_patches"] = image_patches[0] + + # get patch 
grid size for each image + embed_is_patch = [] + for image in images: + ncols, nrows = self.info.get_image_feature_grid_size( + image_width=image.width, + image_height=image.height, + ) + + mask = torch.tensor(([True] * ncols + [False]) * nrows) + embed_is_patch.append(mask) + + processed_outputs["embed_is_patch"] = embed_is_patch + + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + # HF processor adds boa_token_id + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + boa_token_id = vocab["<0x04>"] + + return prompt_tokens + [boa_token_id] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(image_patches=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + bos_token_id = hf_config.bos_token_id + assert isinstance(bos_token_id, int) + + tokenizer = self.info.get_tokenizer() + eot_token_id = tokenizer.bos_token_id + assert isinstance(eot_token_id, int) + + def get_replacement_fuyu(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = self.info.get_image_feature_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + + [_NEWLINE_TOKEN_ID]) * nrows + + return PromptUpdateDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ) + + return [ + PromptReplacement( + modality="image", + target=[eot_token_id], + replacement=get_replacement_fuyu, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, + info=FuyuProcessingInfo, + dummy_inputs=FuyuDummyInputsBuilder) +class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + self.vocab_size = config.text_config.vocab_size + self.image_token_id = _IMAGE_TOKEN_ID + self.image_feature_size = config.patch_size**2 * config.num_channels + + self.vision_embed_tokens = ColumnParallelLinear( + self.image_feature_size, + config.hidden_size, + quant_config=quant_config, + gather_output=True, + ) + self.language_model = PersimmonForCausalLM( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model"), + ) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.patch_size + num_channels = self.config.num_channels + expected_dims = num_channels * h * w + + def _validate_shape(d: torch.Tensor): + actual_dims = d.size(-1) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"per patch is {expected_expr}. 
" + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data.to(self.vision_embed_tokens.weight.dtype) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: + image_patches = kwargs.pop("image_patches", None) + if image_patches is not None: + if not isinstance(image_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image patches. " + f"Got type: {type(image_patches)}") + + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + image_patches_flat = flatten_bn(image_patches) + embed_is_patch = flatten_bn(embed_is_patch) + + return FuyuImagePatchInputs( + type="image_patches", + flat_data=self._validate_pixel_values( + flatten_bn(image_patches_flat, concat=True)), + patches_per_image=[x.size(0) for x in image_patches_flat], + embed_is_patch=embed_is_patch, + ) + + return None + + def _process_image_input( + self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + + assert self.vision_embed_tokens is not None + vision_embeddings_flat, _ = self.vision_embed_tokens( + image_patches_flat) + + return vision_embeddings_flat.split(patches_per_image, dim=0) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, + select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.language_model.logits_processor( + self.language_model.lm_head, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.language_model.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py new file mode 100644 index 0000000..d741880 --- /dev/null +++ b/vllm/model_executor/models/gemma.py @@ -0,0 +1,428 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 The vLLM team. +# Copyright (c) Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Gemma model compatible with HuggingFace weights.""" +from functools import cache +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import GemmaConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +@cache +def _get_gemma_act_fn( + hidden_act: Optional[str], + hidden_activation: Optional[str], +) -> nn.Module: + if hidden_activation is None: + if hidden_act is not None: + logger.warning( + "Gemma's activation function was incorrectly set to exact GeLU " + "in the config JSON file when it was initially released. " + "Changing the activation function to approximate GeLU " + "(`gelu_pytorch_tanh`). If you want to use the legacy " + "`%s`, edit the config JSON to set " + "`hidden_activation=%s` instead of `hidden_act`. 
" + "See https://github.com/huggingface/transformers/pull/29402 " + "for more details.", hidden_act, hidden_act) + return GeluAndMul(approximate="tanh") + elif hidden_activation == "gelu_pytorch_tanh": + return GeluAndMul(approximate="tanh") + elif hidden_activation == "gelu": + return GeluAndMul(approximate="none") + else: + raise ValueError(f"Activation function {hidden_act} is not " + "supported for Gemma models.") + + +class GemmaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: Optional[str] = None, + hidden_activation: Optional[str] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GemmaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int = 8192, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class GemmaDecoderLayer(nn.Module): + + def __init__( + self, + config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GemmaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = GemmaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=getattr(config, "hidden_activation", None), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class GemmaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = 
make_layers( + config.num_hidden_layers, + lambda prefix: GemmaDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = GemmaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + 
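+    # The HF checkpoint stores q/k/v and gate/up projections as separate
+    # tensors; load_weights maps them onto the fused qkv_proj and
+    # gate_up_proj parameters and skips lm_head.weight, which is tied to
+    # embed_tokens.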
+ def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py new file mode 100644 index 0000000..d125c66 --- /dev/null +++ b/vllm/model_executor/models/gemma2.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The vLLM team. +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
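+"""Inference-only Gemma2 model compatible with HuggingFace weights."""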
+from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Gemma2Config + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Gemma2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): + raise ValueError( + "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma2Attention(nn.Module): + + def __init__(self, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + + # reference: + # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa + layer_idx = extract_layer_index(prefix) + use_sliding_window = (layer_idx % 2 == 0 and + config.interleaved_sliding_window is not None) + sliding_window = config.interleaved_sliding_window if \ + use_sliding_window else None + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma2DecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma2Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=config.attn_logit_softcapping, + prefix=f"{prefix}.self_attn", + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + 
hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Gemma2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma2DecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if 
shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py new file mode 100644 index 0000000..fb8eccc --- /dev/null +++ b/vllm/model_executor/models/gemma3.py @@ -0,0 +1,539 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM team. +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Gemma3TextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Gemma3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. 
Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3Attention(nn.Module): + + def __init__(self, + config: Gemma3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + # TODO(woosuk): Add reference to the original HF implementation. + layer_idx = extract_layer_index(prefix) + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + # Initialize the rotary embedding. + if self.is_sliding: + # Local attention. Override the values in config.json. + self.rope_theta = config.rope_local_base_freq + self.rope_scaling = {"rope_type": "default"} + self.sliding_window = config.interleaved_sliding_window + else: + # Global attention. Use the values in config.json. + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.sliding_window = None + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + + # Initialize the attention. 
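+        # Sliding-window layers pass their window size to the attention
+        # backend via per_layer_sliding_window; global layers pass None and
+        # attend over the full context.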
+ self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + if not kwargs.get("has_images", False): + # Fast path for text-only inputs. The performance for the text-only + # inputs are not affected by the naive attention below. + output, _ = self.o_proj(attn_output) + return output + + # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens + # that correspond to the same image while using causal attention + # otherwise. Current attention backends cannot handle this pattern, so + # we temporarily use a naive attention implementation with mask tensors. + + # We intentionally keep the attention backend as-is and only override + # `attn_output` with the naive implementation's output. This minimizes + # changes to existing model runners and attention backends. The call to + # `self.attn(q, k, v)` is only used to populate the KV cache - its + # output is discarded and overwritten below. While this duplicates + # computation, it maintains compatibility. + # TODO(woosuk): Optimize by implementing custom attention kernels. + attn_output = self.naive_attn_with_masks(q, + k, + v, + out=attn_output, + **kwargs) + output, _ = self.o_proj(attn_output) + return output + + def naive_attn_with_masks( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # NOTE(woosuk): As described in the comment above, this code is not + # meant to be performant. It is only meant to be correct. + q = q.view(-1, self.num_heads, self.head_dim) + # Expand the key and value to handle GQA. + num_queries_per_kv = self.num_heads // self.num_kv_heads + k = k.view(-1, self.num_kv_heads, self.head_dim) + k = k.repeat_interleave(num_queries_per_kv, dim=-2) + v = v.view(-1, self.num_kv_heads, self.head_dim) + v = v.repeat_interleave(num_queries_per_kv, dim=-2) + + if self.is_sliding: + attn_masks = kwargs["local_attn_masks"] + else: + attn_masks = kwargs["global_attn_masks"] + + seq_lens = kwargs["seq_lens"] + start_idx = 0 + for seq_len, attn_mask in zip(seq_lens, attn_masks): + end_idx = start_idx + seq_len + query = q[start_idx:end_idx].unsqueeze(0) + key = k[start_idx:end_idx].unsqueeze(0) + value = v[start_idx:end_idx].unsqueeze(0) + + # Transpose. 
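+            # (1, seq_len, num_heads, head_dim) -> (1, num_heads, seq_len,
+            # head_dim), the layout expected by scaled_dot_product_attention.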
+ query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask, + self.scaling, + ) + output = output.transpose(1, 2).flatten(-2, -1) + out[start_idx:end_idx] = output + start_idx = end_idx + return out + + +class Gemma3DecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma3TextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=None, + prefix=f"{prefix}.self_attn", + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Gemma3Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3DecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. 
+ # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + # NOTE(woosuk): Only apply the normalizer to the output of + # vocab embedding. Don't apply it to the vision embedding. + return self.embed_tokens(input_ids) * self.normalizer + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + **kwargs, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
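+                # If the scale name cannot be remapped to an existing
+                # parameter, None is returned and the weight is skipped below.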
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py new file mode 100644 index 0000000..bbdea70 --- /dev/null +++ b/vllm/model_executor/models/gemma3_mm.py @@ -0,0 +1,781 @@ +# SPDX-License-Identifier: Apache-2.0 +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import BatchFeature, Gemma3Config, Gemma3Processor +from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalFieldConfig +from 
vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +# yapf: disable +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, BoundPromptUpdate, + PlaceholderFeaturesInfo, + PromptReplacement, PromptTargetMatch, + PromptUpdate, PromptUpdateDetails, + encode_tokens, find_mm_placeholders, + replace_token_matches) +# yapf: enable +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features + +logger = init_logger(__name__) + + +class Gemma3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(num_patches_total, num_channels, height, width)` + + `num_patches_total` is the total number of patches + over each image over each prompt in the batch. + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_images, num_embeds)` + """ + + +Gemma3ImageInputs = Gemma3ImagePixelInputs + + +class Gemma3ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Gemma3Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(Gemma3Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def _resolve_image_kwargs( + self, + processor: Gemma3Processor, + keys: set[str], + ) -> dict[str, Any]: + image_processor = processor.image_processor + kwargs = processor._merge_kwargs( + Gemma3ProcessorKwargs, + tokenizer_init_kwargs=processor.tokenizer.init_kwargs, + ) + + images_kwargs = kwargs["images_kwargs"] + + def _resolve_kw(key: str): + val = getattr(image_processor, key) + if val is None: + val = images_kwargs[key] + + return val + + return {k: _resolve_kw(k) for k in keys} + + def get_num_crops( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Gemma3Processor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + images_kwargs = self._resolve_image_kwargs( + processor, { + "do_pan_and_scan", "pan_and_scan_min_crop_size", + "pan_and_scan_max_num_crops", + "pan_and_scan_min_ratio_to_activate" + }) + + do_pan_and_scan = images_kwargs["do_pan_and_scan"] + pan_and_scan_min_crop_size = images_kwargs[ + "pan_and_scan_min_crop_size"] + pan_and_scan_max_num_crops = images_kwargs[ + "pan_and_scan_max_num_crops"] + pan_and_scan_min_ratio_to_activate = images_kwargs[ + "pan_and_scan_min_ratio_to_activate"] + + if not do_pan_and_scan: + return 0 + + if envs.VLLM_USE_V1: + logger.warning_once( + "`do_pan_and_scan=True` has suboptimal results on V1 " + "because of the simplified attention pattern being used.") + + # Based on Gemma3ImageProcessor.pan_and_scan + if image_width >= image_height: + if image_width / image_height < pan_and_scan_min_ratio_to_activate: + return 0 + + 
num_crops_w = min( + int(math.floor(image_width / pan_and_scan_min_crop_size)), + int(math.floor(image_width / image_height + 0.5)), + ) + + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + else: + if image_height / image_width < pan_and_scan_min_ratio_to_activate: + return 0 + + num_crops_h = min( + int(math.floor(image_height / pan_and_scan_min_crop_size)), + int(math.floor(image_height / image_width + 0.5)), + ) + + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(image_width / num_crops_w)) + crop_size_h = int(math.ceil(image_height / num_crops_h)) + + if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return 0 + + return num_crops_w * num_crops_h + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Gemma3Processor], + ) -> PromptUpdateDetails[str]: + if processor is None: + processor = self.get_hf_processor() + + image_token = processor.boi_token + + num_crops = self.get_num_crops( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + if num_crops == 0: + image_text = image_token + else: + crops_image_tokens = " ".join(image_token + for _ in range(num_crops)) + image_text = ( + f"Here is the original image {image_token} and here are some " + f"crops to help you see better {crops_image_tokens}") + + repl_full = image_text.replace(image_token, + processor.full_image_sequence) + repl_features = repl_full.strip("\n") + + return PromptUpdateDetails(full=repl_full, features=repl_features) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Gemma3Processor], + ) -> int: + tokenizer = self.get_tokenizer() + image_repl = self.get_image_repl( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + image_repl_tokens = encode_tokens( + tokenizer, + image_repl.features, + add_special_tokens=False, + ) + return len(image_repl_tokens) + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + images_kwargs = self._resolve_image_kwargs( + processor, {"pan_and_scan_max_num_crops"}) + max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] + + # Result in the max possible feature size (h:w = max_num_crops:1) + return ImageSize(height=50 * max_num_crops, width=50) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + ) + + +class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + processor = self.info.get_hf_processor() + image_token = processor.boi_token + + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> 
BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + ) + + # HF processor pops the `num_crops` kwarg, which is needed by vLLM + if (images := mm_data.get("images")) is not None: + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems)) + image_sizes = [ + parsed_images.get_image_size(i) + for i in range(len(parsed_images)) + ] + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + image_repl_features = [ + self.info.get_image_repl(image_width=size.width, + image_height=size.height, + processor=hf_processor).features + for size in image_sizes + ] + + tokenizer = self.info.get_tokenizer() + image_repls_feature_tokens = [ + tokenizer.encode(image_repl, add_special_tokens=False) + for image_repl in image_repl_features + ] + + vocab = tokenizer.get_vocab() + image_token_id = vocab[tokenizer.image_token] + + embed_is_patch = [ + torch.tensor(image_repl_tokens) == image_token_id + for image_repl_tokens in image_repls_feature_tokens + ] + processed_outputs["embed_is_patch"] = embed_is_patch + + num_crops = [ + self.info.get_num_crops(image_width=size.width, + image_height=size.height, + processor=hf_processor) + for size in image_sizes + ] + processed_outputs["num_crops"] = torch.tensor(num_crops) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_crops = hf_inputs.get("num_crops", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops + 1), + num_crops=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.boi_token + + def get_replacement_gemma3(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + return self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_gemma3, + ) + ] + + def _apply_token_matches( + self, + prompt: list[int], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], + ) -> list[int]: + token_ids = super()._apply_token_matches( + prompt, + mm_matches, + mm_item_counts, + ) + + # "\n\n\n" and "\n\n\n\n" are single tokens + # Since our replacement can insert "\n\n" next to "\n" + # tokens, we have to combine them to be consistent with + # the output of the tokenizer + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + token_ids = replace_token_matches( + token_ids, + [newline_1, newline_2], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_1], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_2], + [newline_4], + ) + + return token_ids + + def _find_mm_placeholders( + self, + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + new_token_ids: 
list[int], + mm_item_counts: Mapping[str, int], + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + def get_repl_toks(tok: int) -> list[int]: + if tok == newline_3: + return [newline_1, newline_2] + if tok == newline_4: + return [newline_2, newline_2] + + return [tok] + + repl_token_ids = list[int]() + repl_orig_idxs = list[int]() + for orig_idx, orig_tok in enumerate(new_token_ids): + repl_toks = get_repl_toks(orig_tok) + repl_token_ids.extend(repl_toks) + repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) + + repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, + mm_item_counts) + + return { + modality: [ + PlaceholderFeaturesInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=repl_orig_idxs[p.start_idx], + tokens=p.tokens, + ) for p in placeholders + ] + for modality, placeholders in repls.items() + } + + +class Gemma3MultiModalProjector(nn.Module): + + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, + config.text_config.hidden_size)) + + self.mm_soft_emb_norm = GemmaRMSNorm( + config.vision_config.hidden_size, + eps=config.vision_config.layer_norm_eps) + + self.patches_per_image = int(config.vision_config.image_size // + config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, + stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, + self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder) +class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, + SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.sliding_window = config.text_config.interleaved_sliding_window + + self.vision_tower = SiglipVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_tower")) + self.multi_modal_projector = 
Gemma3MultiModalProjector(config) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma3ForCausalLM"], + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + if d.shape != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_dims}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Gemma3ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + num_crops = kwargs.pop("num_crops", None) + embed_is_patch = kwargs.pop("embed_is_patch", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma3 does not support image_embeds." + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(num_crops, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_crops. " + f"Got type: {type(num_crops)}") + + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. 
" + f"Got type: {type(embed_is_patch)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + num_crops = flatten_bn(num_crops, concat=True) + embed_is_patch = flatten_bn(embed_is_patch) + + return Gemma3ImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values(pixel_values), + num_patches=num_crops + 1, + embed_is_patch=embed_is_patch, + ) + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) + return image_features + + def _process_image_input( + self, + image_input: Gemma3ImageInputs, + ) -> list[torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input["pixel_values"] + num_patches = image_input["num_patches"] + + image_features = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + image_embeds = self.multi_modal_projector(image_features) + + return [ + e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist()) + ] + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.config.image_token_index, + ) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + if vision_embeddings is not None: + kwargs = self.prepare_attn_masks( + input_ids, + positions, + mask_dtype=self.dtype, + **kwargs, + ) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs) + + return hidden_states + + def prepare_attn_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + **kwargs, + ): + kwargs["has_images"] = True + # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. + # This is a HACK. Fix this. 
+ start_idices = (positions == 0).cpu().nonzero() + num_seqs = len(start_idices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_idices[i].item() + if i < num_seqs - 1: + end_idx = start_idices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + kwargs["seq_lens"] = seq_lens + + global_attn_masks = [] + local_attn_masks = [] + start_idx = 0 + for seq_len in seq_lens: + end_idx = start_idx + seq_len + input_token_ids = input_ids[start_idx:end_idx] + start_idx = end_idx + # Create a global causal mask. + global_attn_mask = torch.empty( + 1, + 1, + seq_len, + seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0. + global_attn_mask = global_attn_mask.triu(diagonal=1) + + # Consider the bidirectional attention between image tokens. + img_mask = torch.zeros_like(global_attn_mask) + img_pos = (input_token_ids == self.config.image_token_index) + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + # Create a local causal mask with sliding window (1024). + local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, + diagonal=-self.sliding_window) + local_attn_mask = torch.where(local_attn_mask == 0, + global_attn_mask, float("-inf")) + local_attn_masks.append(local_attn_mask) + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + return kwargs + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="multi_modal_projector", + tower_model="vision_tower") diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py new file mode 100644 index 0000000..8d52da8 --- /dev/null +++ b/vllm/model_executor/models/glm.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only HF format GLM-4 model compatible with THUDM weights.""" +from vllm.config import VllmConfig +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .interfaces import SupportsV0Only +from .utils import PPMissingLayer + + +class GlmForCausalLM(LlamaForCausalLM, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Hack Llama model to fit HF format GLM implementation + # Attention difference between GLM and Llama: + # 1. Half partial rotary_dim and no Neox style. + # 2. 
There is no bias for o_proj in attention + for layer in self.model.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.rotary_emb.rotary_dim //= 2 + layer.self_attn.rotary_emb.is_neox_style = False + layer.self_attn.o_proj.bias = None + layer.self_attn.o_proj.skip_bias_add = True diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py new file mode 100644 index 0000000..6c8c0a2 --- /dev/null +++ b/vllm/model_executor/models/glm4.py @@ -0,0 +1,313 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The Zhipu AI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Glm4Config + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .llama import LlamaMLP as Glm4MLP +from .llama import LlamaModel +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler + + +class Glm4Attention(nn.Module): + + def __init__(self, + config: Glm4Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = 
num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.rotary_dim = self.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + is_neox_style=False, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4DecoderLayer(nn.Module): + + def __init__( + self, + config: Glm4Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + self.self_attn = Glm4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=AttentionType.DECODER, + ) + self.mlp = Glm4MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self 
Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states = self.post_self_attn_layernorm(hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Glm4DecoderLayer, +} + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Glm4Model(LlamaModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=Glm4DecoderLayer) + + +class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Glm4Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py new file mode 100644 index 0000000..c190a45 --- /dev/null +++ b/vllm/model_executor/models/glm4v.py @@ -0,0 +1,651 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# 
https://github.com/THUDM/CogAgent +"""Inference-only CogAgent model compatible with THUDM weights.""" +from argparse import Namespace +from collections.abc import Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import PreTrainedTokenizer, TensorType +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput + +from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, BatchFeature, + MultiModalFieldConfig, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import ChatGLMConfig + +from .chatglm import ChatGLMBaseModel, ChatGLMModel +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import flatten_bn, merge_multimodal_embeddings + + +class GLMVImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + +class EVA2CLIPPatchEmbedding(nn.Module): + + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + images = images.to(device=self.proj.weight.device, + dtype=self.proj.weight.dtype) + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = config.num_heads // self.tp_size + self.head_dim = config.hidden_size // config.num_heads + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_dim, + config.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + self.dense = RowParallelLinear( + config.hidden_size, + 
config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, + self.scale) + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D + q, k, v = qkv.chunk(3, dim=-1) + + out = self.attn(q, k, v) + output, _ = self.dense(out) + output = self.output_dropout(output) + return output + + +class EVA2CLIPMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.activation_fn(x) + x, _ = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = EVA2CLIPAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.mlp = EVA2CLIPMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class EVA2CLIPGLU(nn.Module): + + def __init__( + self, + config, + in_features, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + """ + The original implementation is the same as: + ```python + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + ``` + ``` + gate_proj_output, _ = self.gate_proj(x) + dense_h_to_4h_output, _ = self.dense_h_to_4h(x) + x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1) + ``` + + We merge two ColumnParallelLinear into one MergedColumnParallelLinear: + ``` + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config + ) + ``` + ``` + x, _ = self.merged_proj(x) + ``` + """ + super().__init__() 
+ self.linear_proj = ReplicatedLinear(in_features, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.norm1 = nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = SiluAndMul() + + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.merged_proj") + + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h") + + def forward(self, x): + x, _ = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x, _ = self.merged_proj(x) + x = self.act2(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config) + self.transformer = EVA2CLIPTransformer(vision_config, + quant_config=quant_config, + prefix=f"{prefix}.transformer") + self.linear_proj = EVA2CLIPGLU(config, + in_features=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + x = self.patch_embedding(images) + x = self.transformer(x) + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x + + +class GLM4VModel(ChatGLMModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + quant_config = vllm_config.quant_config + + self.vision = EVA2CLIPModel(self.config, + quant_config, + prefix=f"{prefix}.vision") + + +class GLM4VProcessor: + """ + This model doesn't define its own HF processor, + so we implement our own one here. 
+ """ + + def __init__( + self, + config: ChatGLMConfig, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + vision_config = config.vision_config + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + text_inputs = self.tokenizer(text) + + if len(images) == 0: + image_inputs = {} + else: + pixel_values = [self.image_transform(image) for image in images] + image_inputs = {"pixel_values": torch.stack(pixel_values)} + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + tensor_type=return_tensors, + ) + + +class GLM4VProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(ChatGLMConfig) + + def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor: + return self.ctx.init_processor( + GLM4VProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_feature_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length + + def get_num_image_feature_tokens(self) -> int: + # EVA2CLIPModel has embeddings for boi and eoi tokens as well + return self.get_num_image_tokens() + 2 + + +class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + + target_width = target_height = vision_config["image_size"] + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + + return ProcessorInputs( + prompt_text=base_text * num_images, + mm_data=mm_data, + ) + + +class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + 
hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + + boi_token_id = hf_config.boi_token_id + image_token_id = hf_config.pad_token_id + eoi_token_id = hf_config.eoi_token_id + + def get_replacement(item_idx: int): + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [boi_token_id] + image_tokens + [eoi_token_id] + + return [ + PromptReplacement( + modality="image", + target=[boi_token_id, image_token_id, eoi_token_id], + replacement=get_replacement, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, + info=GLM4VProcessingInfo, + dummy_inputs=GLM4VDummyInputsBuilder) +class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, + SupportsMultiModal): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"], + "merged_proj": ["gate_proj", "dense_h_to_4h"] + } + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="transformer.encoder", + connector="transformer.vision.linear_proj", + tower_model="transformer.vision.transformer") + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[GLM4VModel] = GLM4VModel, + ) -> None: + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + transformer_type=transformer_type, + ) + + self.transformer: GLM4VModel + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config["image_size"] + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[GLMVImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. 
" + f"Got type: {type(pixel_values)}") + + return GLMVImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + return None + + def _process_image_input( + self, image_input: GLMVImagePixelInputs) -> torch.Tensor: + pixel_values = image_input["data"].to(dtype=self.config.torch_dtype) + + return self.transformer.vision(pixel_values) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.transformer.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=[ + self.config.boi_token_id, + self.config.pad_token_id, + self.config.eoi_token_id, + ], + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + + return hidden_states diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py new file mode 100644 index 0000000..776c03f --- /dev/null +++ b/vllm/model_executor/models/gpt2.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only GPT-2 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import GPT2Config + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class GPT2Attention(nn.Module): + + def __init__( + self, + config: GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_attn", + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPT2MLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPT2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_fc", + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + + def __init__( + self, + config: 
GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, + config, + quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn(hidden_states=hidden_states) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +@support_torch_compile +class GPT2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.wte") + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPT2Block( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPT2LMHeadModel(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.transformer = GPT2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + 
self.config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head") + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.transformer.wte) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer.") and not name.startswith( + "lm_head"): + name = "transformer." + name + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py new file mode 100644 index 0000000..43f3d4f --- /dev/null +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2023 CTranslate2, and Michael Feil +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import GPTBigCodeConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class GPTBigCodeAttention(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + self.tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % self.tensor_model_parallel_world_size == 0 + self.num_heads = (total_num_heads // + self.tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.multi_query = config.multi_query + if self.multi_query: + total_num_kv_heads = 1 + self.num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + self.num_kv_heads = self.num_heads + self.kv_dim = self.head_dim * self.num_kv_heads + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ + self.hidden_size // self.tensor_model_parallel_world_size, + self.kv_dim, self.kv_dim + ], + dim=-1, + ) + attn_output = self.attn(q, k, v) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPTBigMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTBigCodeConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + self.act = get_act_fn(config.activation_function) + + def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPTBigCodeBlock(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTBigCodeAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPTBigMLP(inner_dim, config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn(hidden_states=hidden_states, ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +@support_torch_compile +class GPTBigCodeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + assert not config.add_cross_attention + + self.embed_dim = config.hidden_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.wte = VocabParallelEmbedding(self.vocab_size, + self.embed_dim, + org_num_embeddings=config.vocab_size) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPTBigCodeBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds + self.wpe(position_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = {"c_attn": ["c_attn"]} + + # LoRA specific attributes + 
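+    # embedding_modules maps this model's embedding submodules (wte for input
+    # embeddings, lm_head for output embeddings) to the generic embedding
+    # roles referenced by vLLM's LoRA support.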
embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.transformer = GPTBigCodeModel(vllm_config=vllm_config, + prefix=prefix) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead( + self.transformer.vocab_size, + self.transformer.embed_dim, + org_num_embeddings=self.config.vocab_size) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method + if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + weight_loader(param, loaded_weight, 'q') + weight_loader(param, loaded_weight, 'k') + weight_loader(param, loaded_weight, 'v') + else: + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py new file mode 100644 index 0000000..752aec0 --- /dev/null +++ b/vllm/model_executor/models/gpt_j.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py +# Copyright 2023 The vLLM team. +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-J model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import GPTJConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class GPTJAttention(nn.Module): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=False, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + scaling = self.head_size**-0.5 + assert getattr(config, "rotary", True) + assert config.rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=config.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class 
GPTJMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTJConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.n_embd + self.fc_in = ColumnParallelLinear( + hidden_size, + intermediate_size, + quant_config=quant_config, + ) + self.fc_out = RowParallelLinear( + intermediate_size, + hidden_size, + quant_config=quant_config, + ) + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc_out(hidden_states) + return hidden_states + + +class GPTJBlock(nn.Module): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + inner_dim = (4 * config.n_embd + if config.n_inner is None else config.n_inner) + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTJAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.mlp = GPTJMLP(inner_dim, config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + ) + mlp_output = self.mlp(hidden_states) + hidden_states = attn_output + mlp_output + residual + return hidden_states + + +@support_torch_compile +class GPTJModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.embed_dim = config.n_embd + self.wte = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.start_layer, self.end_layer, self.h = make_layers( + config.n_layer, + lambda prefix: GPTJBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTJForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + assert not 
config.tie_word_embeddings + self.transformer = GPTJModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.n_embd, + bias=True, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "attn.bias" in name or "attn.masked_bias" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py new file mode 100644 index 0000000..582b2ff --- /dev/null +++ b/vllm/model_executor/models/gpt_neox.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import GPTNeoXConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class GPTNeoXAttention(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + self.bias = getattr(config, "attention_bias", True) + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=self.bias, + 
quant_config=quant_config, + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=self.bias, + quant_config=quant_config, + ) + scaling = self.head_size**-0.5 + rotary_dim = int(self.head_size * config.rotary_pct) + assert rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.dense(attn_output) + return output + + +class GPTNeoXMLP(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + self.dense_4h_to_h = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + ) + self.act = get_act_fn(config.hidden_act) + + def forward(self, hidden_states): + hidden_states, _ = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = GPTNeoXAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") + self.mlp = GPTNeoXMLP(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + attn_input = self.input_layernorm(hidden_states) + attn_output = self.attention( + position_ids=position_ids, + hidden_states=attn_input, + ) + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_input = self.post_attention_layernorm(attn_output) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + return hidden_states + + +@support_torch_compile +class GPTNeoXModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embed_in = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = 
make_layers( + config.num_hidden_layers, + lambda prefix: GPTNeoXLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_in(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if ("attention.bias" in name or "attention.masked_bias" in name + or "rotary_emb.inv_freq" in name): + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using OpenRLHF may include + # these tensors in the checkpoint. Skip them. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. 
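+                # Illustrative example (hypothetical sizes, not taken from any
+                # checkpoint): with num_heads=2 and head_size=3, the loaded
+                # weight rows are ordered per head as [q0 k0 v0 q1 k1 v1],
+                # while the fused QKV layer expects the per-projection order
+                # [q0 q1 k0 k1 v0 v1]. The view/transpose/reshape below
+                # performs exactly that reordering along output_dim.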
+ output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class GPTNeoXForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "gpt_neox")) + self.embed_out = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.embed_out.weight = self.gpt_neox.embed_in.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.gpt_neox.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.gpt_neox.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.gpt_neox(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.embed_out, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py new file mode 100644 index 0000000..eba8207 --- /dev/null +++ b/vllm/model_executor/models/granite.py @@ -0,0 +1,498 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import GraniteConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers, + maybe_prefix) + + +class GraniteMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = 
getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states + + +@support_torch_compile +class GraniteModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + hidden_states *= self.config.embedding_multiplier + 
else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = GraniteModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + if hasattr(config, "logits_scaling"): + logit_scale /= config.logits_scaling + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # 
(param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py new file mode 100644 index 0000000..5152539 --- /dev/null +++ b/vllm/model_executor/models/granitemoe.py @@ -0,0 +1,439 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only GraniteMoe model.""" +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers.models.granitemoe import GraniteMoeConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from . import mixtral +from .interfaces import SupportsLoRA, SupportsPP +from .utils import make_layers, maybe_prefix + + +class GraniteMoeMoE(nn.Module): + """A tensor-parallel MoE implementation for GraniteMoe that shards each + expert across all ranks. + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class GraniteMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attention_multiplier: Optional[float] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = (attention_multiplier if attention_multiplier + is not None else self.head_dim**-1) + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier) + self.block_sparse_moe = GraniteMoeMoE( + 
num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +@support_torch_compile +class GraniteMoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteMoeDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config 
= vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config # Required by MixtralForCausalLM + + self.model = GraniteMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + self.sampler = get_sampler() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + elif n == 'lm_head.weight' and self.config.tie_word_embeddings: + pass + else: + 
new_weights[n] = p + return mixtral.MixtralForCausalLM.load_weights(self, + new_weights.items()) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py new file mode 100644 index 0000000..7e2e4cd --- /dev/null +++ b/vllm/model_executor/models/granitemoeshared.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only GraniteMoeShared model. + +The architecture is the same as granitemoe but with the addition of shared +experts. +""" +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers.models.granitemoeshared import GraniteMoeSharedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from . import mixtral +from .granitemoe import GraniteMoeAttention, GraniteMoeMoE +from .interfaces import SupportsLoRA, SupportsPP +from .utils import make_layers, maybe_prefix + + +class GraniteMoeSharedMLP(nn.Module): + + def __init__( + self, + config: GraniteMoeSharedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.shared_intermediate_size + self.input_linear = MergedColumnParallelLinear( + input_size=self.input_size, + output_sizes=[self.hidden_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.input_linear") + self.output_linear = RowParallelLinear( + self.hidden_size, + self.input_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_linear") + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.input_linear(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.output_linear(hidden_states) + return hidden_states + + +class GraniteMoeSharedDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeSharedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier) + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + self.shared_mlp = None if \ + getattr(config, 'shared_intermediate_size', 0) == 0 \ + else GraniteMoeSharedMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.shared_mlp is None: + hidden_states = self.block_sparse_moe(hidden_states) + else: + # create a copy since block_sparse_moe modifies in-place + moe_hidden_states = hidden_states.clone() + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + del moe_hidden_states + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +@support_torch_compile +class GraniteMoeSharedModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + 
lambda prefix: GraniteMoeSharedDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = GraniteMoeSharedModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + self.sampler = get_sampler() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + 
device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + elif n == 'lm_head.weight' and self.config.tie_word_embeddings: + pass + else: + new_weights[n] = p + return mixtral.MixtralForCausalLM.load_weights(self, + new_weights.items()) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py new file mode 100644 index 0000000..2984f22 --- /dev/null +++ b/vllm/model_executor/models/gritlm.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 + +from array import array +from typing import Optional, Union + +import torch +import torch.nn as nn +from xformers.ops.fmha.attn_bias import BlockDiagonalMask + +from vllm.attention.backends.xformers import XFormersImpl +from vllm.config import ModelConfig, VllmConfig +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import PoolerHead +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import (IntermediateTensors, PoolerOutput, + PoolingSequenceGroupOutput) +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config + +from .interfaces import SupportsV0Only + +logger = init_logger(__name__) + + +class GritLMPooler(nn.Module): + + def __init__(self, model_config: ModelConfig): + super().__init__() + + self.model_config = model_config + + tokenizer = cached_tokenizer_from_config(self.model_config) + + # Collect the tokens needed for pattern matching. + # "▁<" is different from "_<". The former uses "▁" to indicate that + # the next token is the start of a word. + # "<0x0A>" is the newline token (i.e. "\n")." 
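+ # Illustrative note: the patterns built below are token-ID encodings of
+ # "<|user|>\n", "\n<|embed|>\n" and "<|embed|>\n"; for example
+ # ["▁<", "|", "embed", "|", ">", "<0x0A>"] is the tokenized form of
+ # "<|embed|>\n". Matching is done on token IDs rather than strings because
+ # the pooler only receives the prompt as a list of token IDs.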
+ self.token_ids = { + tok: tokenizer.convert_tokens_to_ids([tok])[0] + for tok in ["", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"] + } + + def tokens_to_ids(tokens: list[str]) -> array: + return array("i", [self.token_ids[token] for token in tokens]) + + self.user_pattern_ids = tokens_to_ids( + ["▁<", "|", "user", "|", ">", "<0x0A>"]) + self.embed_newline_pattern_ids = tokens_to_ids( + ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) + self.embed_pattern_ids = tokens_to_ids( + ["▁<", "|", "embed", "|", ">", "<0x0A>"]) + + self.head = PoolerHead(normalize=True, softmax=False) + + def _find_array(self, arr: array, target: array, start_idx: int) -> int: + """ + Find the first occurrence of target in arr starting from start_idx. + + Args: + arr: The array to search within + target: The consecutive subsequence to find + start_idx: The starting index to search from + + Returns: + int: The index of the first occurrence of target in arr. + """ + if start_idx < 0: + raise ValueError("start_idx must be non-negative") + if not target or not arr: + raise ValueError("Empty arr or target not allowed") + + target_len = len(target) + for i in range(start_idx, len(arr) - target_len + 1): + if arr[i:i + target_len] == target: + return i + return -1 + + def _get_instruction_len(self, prompt_token_ids: array) -> int: + """ + Get the length of the instruction in the prompt. + + We do a pattern matching to find the instruction in the prompt, + and then return the length of the instruction. + + The pattern matching is done using integers instead of strings + because the prompt is given as a list of token IDs. + """ + + instruction_len = 0 + + # Return no instruction in case of missing BOS token. + if prompt_token_ids[0] != self.token_ids[""]: + logger.warning("BOS token not found in prompt, " + "thus using empty string for instruction. " + "GritLM requires BOS token in prompt.") + return instruction_len + + # If user pattern is found in the prompt, that means there should be + # a newline token before the embed pattern. + embed_pattern_ids = self.embed_pattern_ids + if self._find_array(prompt_token_ids, + self.user_pattern_ids, + start_idx=1) == 1: + embed_pattern_ids = self.embed_newline_pattern_ids + + # Find the embed pattern in the prompt. + found_embed_pattern_idx = self._find_array(prompt_token_ids, + embed_pattern_ids, + start_idx=1) + + if found_embed_pattern_idx != -1: + instruction_len = found_embed_pattern_idx + len(embed_pattern_ids) + else: + logger.warning("Query instruction not found in prompt, " + "thus using BOS token as instruction instead. " + "GritLM requires query instruction in prompt.") + instruction_len = 1 + + return instruction_len + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """ + Pool the hidden states by summing the embeddings of + non-instruction tokens. 
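+ The per-sequence sums are divided by the number of non-instruction tokens
+ (i.e. mean pooling) and then normalized by the pooler head.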
+ """ + prompts_token_ids = [ + token_ids.prompt_token_ids_array + for _, token_ids in pooling_metadata.seq_data.items() + ] + + instruction_lens = torch.tensor( + [ + self._get_instruction_len(prompt_token_ids) + for prompt_token_ids in prompts_token_ids + ], + device=hidden_states.device, + ) + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + mask = torch.zeros_like(hidden_states, dtype=torch.bool) + + start_idx = 0 + for prompt_len, instruction_len in zip(prompt_lens, instruction_lens): + end_idx = start_idx + prompt_len + mask[start_idx + instruction_len:end_idx] = True + start_idx = end_idx + + masked_hidden_states = hidden_states.masked_fill(~mask, 0.0) + + sum_embeddings = torch.zeros(len(prompt_lens), + hidden_states.size(1), + device=hidden_states.device) + + start_idx = 0 + for i, prompt_len in enumerate(prompt_lens): + end_idx = start_idx + prompt_len + sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum( + dim=0) + start_idx = end_idx + + num_non_instruction_tokens = prompt_lens - instruction_lens + mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( + 1) + + pooled_data = self.head(mean_embeddings) + + pooled_outputs = [ + PoolingSequenceGroupOutput(data) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) + + +class GritLM(LlamaForCausalLM, SupportsV0Only): + """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. + + The class inherits from LlamaForCausalLM and provides a custom pooling + layer. + + The main difference between the pooling layer in GritLM and the one in + LlamaForCausalLM is that GritLM ignores the query instruction in the prompt + when pooling the hidden states. + + Embedding prompts should be in the following format: + - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT". + - Without instruction: "<|embed|>\nPROMPT". + + Generation prompts should be in the following format: + - "<|user|>\nPROMPT\n<|assistant|>\n" + """ + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + self.runner_type = vllm_config.model_config.runner_type + + self._pooler = GritLMPooler(vllm_config.model_config) + + for layer in self.model.layers: + if self.runner_type == "pooling" and hasattr(layer, "self_attn"): + assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( + "GritLM embedding is only supported by XFormers backend, " + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + + # Change attention to non-causal for pooling tasks. 
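+ # Illustrative note: BlockDiagonalMask.from_seqlens builds a block-diagonal
+ # attention bias, so each prompt attends bidirectionally to its own tokens
+ # while staying isolated from the other prompts in the batch; this replaces
+ # the causal mask used for generation, which is why the constructor asserts
+ # the XFormers attention backend when running in pooling mode.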
+ if self.runner_type == "pooling": + attn_metadata = get_forward_context().attn_metadata + assert attn_metadata.prefill_metadata.attn_bias is None + attn_metadata.prefill_metadata.attn_bias = [ + BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + ] + + return super().forward( + input_ids=input_ids, + positions=positions, + **kwargs, + ) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py new file mode 100644 index 0000000..f2e8201 --- /dev/null +++ b/vllm/model_executor/models/grok1.py @@ -0,0 +1,565 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from +# https://github.com/ROCm/vllm/blob/cea7419f151cc50293a05b7fac8547f8f887c9f6/vllm/model_executor/models/grok1.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
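+# Editorial note: Grok-1 checkpoints use bespoke names for the MoE expert
+# projections ("linear", "linear_v" and "linear_1" for the gate, up and down
+# projections) and store RMSNorm weights as "norm.scale"; load_weights() below
+# remaps both to vLLM's standard parameter names.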
+"""Inference-only Grok1 model.""" +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +# Default Grok1-specific constants, overridden by config values if present +DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845 +DEFAULT_OUTPUT_MULTIPLIER_SCALE = 0.5773502691896257 +DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169 + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + activation="gelu", + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
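+ # Illustrative note: the input is flattened to (num_tokens, hidden_size) so
+ # the router and the fused MoE kernel operate on a 2D tensor, and the
+ # original shape is restored at the end. Router logits are soft-capped to
+ # (-30, 30) via 30 * tanh(x / 30) before expert selection.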
+ orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + router_logits = 30.0 * F.tanh(router_logits / 30.0) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class Grok1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + config=None, # Added config parameter + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.config = config # Store config reference + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + attn_logits_soft_cap = max( + getattr(config, "attn_logit_softcapping", 30.0), 0.0) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + + # Apply attention output multiplier if specified in config + attn_multiplier = getattr(self.config, "attn_output_multiplier", + None) if self.config else None + if attn_multiplier is not None: + output = output * attn_multiplier + return output + + +class Grok1DecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Check for fp8 quantization + self.use_fp8 = False + if quant_config 
is not None: + self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", + lambda: False)() + if not self.use_fp8 and hasattr(quant_config, "is_fp8"): + self.use_fp8 = quant_config.is_fp8 + + # Requires transformers > 4.32.0 + # Default rope_theta value if not in config + rope_theta = 10000 + self.attn = Grok1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + config=config) # Pass config to Grok1Attention + + # Grok1 uses "num_experts" in its config + num_experts = getattr(config, "num_experts", 8) + num_experts_per_tok = getattr(config, "num_experts_per_tok", 2) + + self.moe_block = Grok1MoE(num_experts=num_experts, + top_k=num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block") + + self.pre_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states) + else: + hidden_states, residual = self.pre_attn_norm( + hidden_states, residual) + + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Post attention normalization + hidden_states = self.post_attn_norm(hidden_states) + + # MoE block with normalization + hidden_states, residual = self.pre_moe_norm(hidden_states, residual) + hidden_states = self.moe_block(hidden_states) + hidden_states = self.post_moe_norm(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Grok1Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embedding_multiplier_scale = getattr( + config, "embedding_multiplier_scale", + DEFAULT_EMBEDDING_MULTIPLIER_SCALE) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Grok1DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, 
input_ids: torch.Tensor) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier_scale + return hidden_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = Grok1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.output_multiplier_scale = getattr( + config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + self.output_multiplier_scale) + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, 
+ sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Map Grok1's unique expert parameter names to standard names + # Grok1 uses "num_experts" in its config + num_experts = getattr(self.config, "num_experts", 8) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", # Grok1 specific + ckpt_down_proj_name="linear_1", # Grok1 specific + ckpt_up_proj_name="linear_v", # Grok1 specific + num_experts=num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Remapping the name of FP8 kv-scale. 
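+ # (maybe_remap_kv_scale_name returns None when the scale cannot be matched
+ # to a parameter, and such weights are skipped.) Editorial note: this is the
+ # generic fallback branch for weights that are neither stacked QKV shards
+ # nor MoE expert weights; it also renames Grok-1's "norm.scale" parameters
+ # to "norm.weight" and skips lm_head when word embeddings are tied.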
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + # Handle Grok1-specific norm.scale naming + if "norm.scale" in name: + name = name.replace("scale", "weight") + + # Skip lm_head when tie_word_embeddings is True + if "lm_head" in name and self.config.tie_word_embeddings: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py new file mode 100644 index 0000000..3b2ad69 --- /dev/null +++ b/vllm/model_executor/models/h2ovl.py @@ -0,0 +1,565 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- +from collections.abc import Mapping, Sequence +from typing import Optional, Union + +import torch +from PIL import Image +from transformers import PretrainedConfig + +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + MultiModalDataItems) +from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.transformers_utils.tokenizer import AnyTokenizer + +from .intern_vit import InternVisionModel +from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, + BaseInternVLProcessingInfo, BaseInternVLProcessor, + InternVLChatModel, InternVLDummyInputsBuilder, + InternVLMultiModalProcessor, build_transform, + find_closest_aspect_ratio, get_internvl_target_ratios) + + +def resolve_h2ovl_min_max_num( + *, + min_dynamic_patch: int, + max_dynamic_patch: int, + dynamic_image_size: bool, + use_thumbnail: bool, +) -> tuple[int, int]: + min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 + + if use_thumbnail and max_dynamic_patch != 1: + max_dynamic_patch += 1 + + return min_dynamic_patch, max_dynamic_patch + + +def get_h2ovl_target_ratios( + min_num: int, + max_num: int, + *, + prior_aspect_ratio: Optional[tuple[int, int]], +) -> list[tuple[int, int]]: + target_ratios = get_internvl_target_ratios(min_num, max_num) + + # if prior_aspect_ratio is provided, filter the target ratios + if prior_aspect_ratio is not None: + target_ratios = [ + ratio for ratio in target_ratios if prior_aspect_ratio[0] % + ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ] + + return target_ratios + + +# modified to include blocks generated in second pass +def calculate_h2ovl_targets( + *, + orig_width: int, + orig_height: int, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[int, int, int, tuple[int, int]]: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + + # calculate the target width and height + 
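+ # (Illustrative values: image_size=448 with target_aspect_ratio=(2, 3) gives
+ # a 896 x 1344 target and 6 blocks, plus one thumbnail block when
+ # use_thumbnail is set.)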
target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # add thumbnail image if num_blocks != 1 + if use_thumbnail and blocks != 1: + blocks += 1 + + return blocks, target_width, target_height, target_aspect_ratio + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +# refactored to handle prior_aspect_ratio +def dynamic_preprocess_h2ovl( + image: Image.Image, + *, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[list[Image.Image], tuple[int, int]]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + ( + blocks, + target_width, + target_height, + target_aspect_ratio, + ) = calculate_h2ovl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False, + ) + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images, target_aspect_ratio + + +def _preprocess_image( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + prior_aspect_ratio: Optional[tuple[int, int]], +) -> tuple[torch.Tensor, tuple[int, int]]: + target_ratios = get_h2ovl_target_ratios( + min_num, + max_num, + prior_aspect_ratio=prior_aspect_ratio, + ) + + transform = build_transform(input_size=input_size) + images, target_aspect_ratio = dynamic_preprocess_h2ovl( + image, + image_size=input_size, + use_thumbnail=use_thumbnail, + target_ratios=target_ratios, + ) + + pixel_values = torch.stack([transform(image) for image in images]) + return pixel_values, target_aspect_ratio + + +# refactored to use the _preprocess_image function +def image_to_pixel_values_h2ovl( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + use_msac: bool, +) -> torch.Tensor: + # when MSAC is turned on, we need to process the image twice + if use_msac: + # first pass + pixel_values1, aspect_ratio1 = _preprocess_image( + image, + input_size=input_size, + min_num=1, + max_num=max_num, + use_thumbnail=True, + prior_aspect_ratio=None, + ) + # second pass + pixel_values2, _ = _preprocess_image( + image, + input_size=input_size, + min_num=3, + max_num=max_num, + use_thumbnail=True, + prior_aspect_ratio=aspect_ratio1, + ) + # combine pixel values + pixel_values = torch.cat( + [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0) + + else: + pixel_values, _ = _preprocess_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=use_thumbnail, + prior_aspect_ratio=None, + ) + + return pixel_values + + +class H2OVLProcessor(BaseInternVLProcessor): + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + 
dynamic_image_size: Optional[bool] = None, + use_msac: Optional[bool] = None, + ) -> None: + super().__init__( + config, + tokenizer, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + + if use_msac is None: + use_msac = config.use_msac + assert isinstance(use_msac, bool) + + self.use_msac = use_msac + + @property + def image_token_id(self) -> int: + return self.tokenizer.get_vocab()[IMG_CONTEXT] + + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + repl_features = IMG_CONTEXT * feature_size + repl_full = IMG_START + repl_features + IMG_END + + return PromptUpdateDetails(full=repl_full, features=repl_features) + + def resolve_min_max_num( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> tuple[int, int]: + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) + max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch + is None else max_dynamic_patch) + dynamic_image_size = (self.dynamic_image_size if dynamic_image_size + is None else dynamic_image_size) + use_thumbnail = (self.use_thumbnail + if use_thumbnail is None else use_thumbnail) + + return resolve_h2ovl_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + def resolve_target_ratios( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + prior_aspect_ratio: Optional[tuple[int, int]] = None, + override_min_num: Optional[int] = None, + ) -> list[tuple[int, int]]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + if override_min_num is not None: + min_num = override_min_num + + return get_h2ovl_target_ratios( + min_num, + max_num, + prior_aspect_ratio=prior_aspect_ratio, + ) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + use_msac: Optional[bool] = None, + ) -> int: + use_msac = (self.use_msac if use_msac is None else use_msac) + + use_thumbnail = self.use_thumbnail + + if use_msac: + target_ratios_1 = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + override_min_num=1, + ) + num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios_1, + use_thumbnail=True, + ) + + target_ratios_2 = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + prior_aspect_ratio=aspect_ratio_1, + override_min_num=3, + ) + num_patches_2, _, _, _ = calculate_h2ovl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios_2, + use_thumbnail=True, + ) + + num_patches = num_patches_1 + num_patches_2 - 1 + else: + target_ratios = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + ) + num_patches, _, _, _ = calculate_h2ovl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + 
target_ratios=target_ratios, + use_thumbnail=use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + use_msac = self.use_msac if len(images) == 1 else False + + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=False, # Applied in image_to_pixel_values + ) + + return [ + image_to_pixel_values_h2ovl( + image, + input_size=self.image_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=self.use_thumbnail, + use_msac=use_msac, + ) for image in images + ] + + +class H2OVLProcessingInfo(BaseInternVLProcessingInfo): + + def get_hf_processor( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + **kwargs: object, + ) -> H2OVLProcessor: + if min_dynamic_patch is not None: + kwargs["min_dynamic_patch"] = min_dynamic_patch + if max_dynamic_patch is not None: + kwargs["max_dynamic_patch"] = max_dynamic_patch + if dynamic_image_size is not None: + kwargs["dynamic_image_size"] = dynamic_image_size + + return self.ctx.init_processor( + H2OVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_tokens_one_image = self.get_max_image_tokens(use_msac=None) + if mm_counts.get("image", 0) <= 1: + max_tokens_per_image = max_tokens_one_image + else: + max_tokens_per_image = self.get_max_image_tokens(use_msac=False) + + return {"image": max_tokens_per_image} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[H2OVLProcessor], + use_msac: Optional[bool] = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + use_msac=use_msac, + ) + + def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + use_msac=use_msac, + ) + + +class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] + ): + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + if "image_num_patches" in out_mm_kwargs: + image_num_patches = out_mm_kwargs["image_num_patches"] + assert isinstance(image_num_patches, torch.Tensor) + image_num_patches = image_num_patches.tolist() + elif "image_embeds" in out_mm_kwargs: + # TODO: Use image size information in dictionary embedding inputs + # to compute num_patches (similar to Qwen2-VL) + image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + else: + image_num_patches = [] + + num_images = len(image_num_patches) + + def get_replacement_internvl(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, 
ImageEmbeddingItems): + feature_size = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + feature_size = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + use_msac=None if num_images == 1 else False, + ) + + num_patches = image_num_patches[item_idx] + if num_patches is not None: + assert isinstance(num_patches, int) + + return hf_processor.get_image_repl(feature_size, num_patches) + + return [ + PromptReplacement( + modality="image", + target="", + replacement=get_replacement_internvl, + ) + ] + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs, bool]: + # The processor logic is different for len(images) <= 1 vs > 1 + # Since the processing cache assumes that the processor output is + # invariant of how many images are passed per prompt, we only + # perform caching for the most common case + if mm_data_items.get_count("image", strict=False) > 1: + # This code path corresponds to the cache being disabled + return self._apply_hf_processor_main( + prompt=prompt, + mm_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + enable_hf_prompt_update=True, + ) + + return super()._cached_apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + H2OVLMultiModalProcessor, + info=H2OVLProcessingInfo, + dummy_inputs=InternVLDummyInputsBuilder) +class H2OVLChatModel(InternVLChatModel): + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = (config.vision_config.num_hidden_layers + + vision_feature_layer + 1) + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + msg = "Monolith mode is not applicable to H2OVL" + raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py new file mode 100644 index 0000000..cb0379c --- /dev/null +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
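+# Editorial note: this vision encoder runs without a KV cache; attention uses
+# vLLM's MultiHeadAttention layer, and position embeddings are bucketized per
+# image so crops of different resolutions can share one learned embedding
+# table (see Idefics2VisionEmbeddings below).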
+"""PyTorch Idefics2 model.""" + +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2Config, Idefics2VisionConfig) + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings + ` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model(which uses images of fixed-size square + images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(target_dtype)) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, + max_nb_patches_h * max_nb_patches_w), + fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class 
Idefics2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj( + hidden_states + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim + query_states, key_states, value_states = qkv.chunk(3, dim=-1) + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.out_proj(out) + return attn_output + + +class Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. 
+ + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. Each layer is a + [`Idefics2EncoderLayer`]. + + Args: + config: Idefics2Config + """ + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + Idefics2EncoderLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectorsthan the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder( + config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder") + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + + self.require_post_norm = require_post_norm + self.post_layernorm = nn.LayerNorm( + embed_dim, + eps=config.layer_norm_eps, + ) if require_post_norm else nn.Identity() + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes, + ) + encoder_outputs = self.encoder(hidden_states) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # skip pooling header + if name.startswith("head."): + continue + + # post_layernorm is optional + if (name.startswith("post_layernorm.") + and not self.require_post_norm): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers."): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py new file mode 100644 index 0000000..da4a443 --- /dev/null +++ b/vllm/model_executor/models/idefics3.py @@ -0,0 +1,831 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
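`load_weights` in `idefics2_vision_model.py` above folds the separate q/k/v projection weights from the HF checkpoint into the fused `qkv_proj` parameter via `stacked_params_mapping`. A minimal sketch of just the name-rewriting step (the checkpoint names are hypothetical examples):

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap(name: str):
    """Return (target_param_name, shard_id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # loaded as-is with the default weight loader

print(remap("encoder.layers.0.self_attn.q_proj.weight"))
# -> ('encoder.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("post_layernorm.weight"))
# -> ('post_layernorm.weight', None)
```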
+"""Inference-only Idefics3 model compatible with HuggingFace weights.""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, + Idefics3Processor) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.parse import ImageProcessorItems, ImageSize +# yapf conflicts with isort for this block +# yapf: disable +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalDataItems, + MultiModalFieldConfig, + PromptReplacement, PromptUpdate, + encode_tokens) +# yapf: enable +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +# yapf: disable +from .idefics2_vision_model import ( + Idefics2VisionTransformer as Idefics3VisionTransformer) +# yapf: enable +from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal +from .llama import LlamaModel +from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features + + +class Idefics3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(batch_size * num_images * num_patches, + num_channels, height, width)` + """ + pixel_attention_mask: torch.Tensor + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_images, num_embeds)` + """ + + +class Idefics3ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + `hidden_size` must match the hidden size of language model backbone. + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. 
+ + Shape: `(batch_size * num_images, num_embeds)` + """ + + +ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] + + +class Idefics3ProcessingInfo(BaseProcessingInfo): + + def get_hf_processor( + self, + *, + size: Optional[Dict[str, int]] = None, + **kwargs: object, + ) -> Idefics3Processor: + if size is not None: + kwargs["size"] = size + + return self.ctx.get_hf_processor(Idefics3Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def _resize_output_size(self, + *, + height: int, + width: int, + max_len: Optional[int] = None, + min_len: int = 1, + max_size: Optional[int] = None) -> tuple[int, int]: + # Set default value for max_len if not provided + max_len = max(height, width) if max_len is None else max_len + aspect_ratio = width / height + + # Handle the maximum size constraint + if max_size is not None: + max_len = min(max_len, max_size) + + # Adjust dimensions according to the aspect ratio + if width >= height: + width = max_len + height = int(width / aspect_ratio) + else: + height = max_len + width = int(height * aspect_ratio) + + # Ensure both width and height are even (if needed) + height += height % 2 + width += width % 2 + + # Ensure dimensions are not smaller than the minimum length + height = max(height, min_len) + width = max(width, min_len) + + return height, width + + def _get_resize_output_image_size( + self, + *, + image_width: int, + image_height: int, + resolution_max_side: int, + ) -> tuple[int, int]: + hf_processor = self.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + max_image_size = image_processor.size['longest_edge'] + if resolution_max_side > max_image_size: + raise ValueError( + "`resolution_max_side` cannot be larger than `max_image_size`") + + height, width = image_height, image_width + + # Find the output size, when rescaling the longest edge to max_len and + # preserving the aspect ratio + height, width = self._resize_output_size(height=height, + width=width, + max_len=resolution_max_side) + return height, width + + def _get_image_feature_grid_size( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> tuple[int, int]: + if processor is None: + processor = self.get_hf_processor() + + image_processor: Idefics3ImageProcessor = processor.image_processor + + max_image_size = image_processor.max_image_size['longest_edge'] + size = image_processor.size['longest_edge'] + assert size % max_image_size == 0, ( + "`longest_edge` in image_processor's `size` must be divisible by " + "`longest_edge` in `max_image_size`, this may be caused by " + "incorrect mm_kwargs override.") + + resized_height, resized_width = self._get_resize_output_image_size( + image_width=image_width, + image_height=image_height, + resolution_max_side=size, + ) + if resized_height > max_image_size or resized_width > max_image_size: + grid_h = math.ceil(resized_height / max_image_size) + grid_w = math.ceil(resized_width / max_image_size) + else: + grid_h = grid_w = 0 + return grid_w, grid_h + + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> int: + grid_w, grid_h = self._get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + 
processor=processor, + ) + + return grid_w * grid_h + 1 + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> str: + if processor is None: + processor = self.get_hf_processor() + + image_token = processor.image_token.content + fake_image_token = processor.fake_image_token.content + global_img_token = processor.global_image_tag + image_seq_len = processor.image_seq_len + grid_placeholder = "" + + p_img = image_token * image_seq_len + global_img_placeholder = fake_image_token + global_img_token + p_img + tile_img_placeholder = fake_image_token + grid_placeholder + p_img + + grid_w, grid_h = self._get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + if grid_w == 0 and grid_h == 0: + return global_img_placeholder + fake_image_token + + tiles_placeholder = list[str]() + for i in range(grid_h): + for j in range(grid_w): + placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1, + n_w=j + 1) + tiles_placeholder.append(placeholder_per_tile) + # Add line break if it is the last tile in the row + if j == grid_w - 1: + tiles_placeholder.append("\n") + + return "".join([ + *tiles_placeholder, + "\n", + global_img_placeholder, + fake_image_token, + ]) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> int: + tokenizer = self.get_tokenizer() + image_repl = self.get_image_repl( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + image_repl_tokens = encode_tokens( + tokenizer, + image_repl, + add_special_tokens=False, + ) + return len(image_repl_tokens) + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + image_processor: Idefics3ImageProcessor = processor.image_processor + + return ImageSize( + width=image_processor.size["longest_edge"], + height=image_processor.size["longest_edge"], + ) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + ) + + +class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] + ): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + hf_processor = self.info.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + longest_edge = image_processor.max_image_size['longest_edge'] + image_token = hf_processor.image_token.content + + mm_data = { + "image": + self._get_dummy_images(width=longest_edge, + height=longest_edge, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class Idefics3MultiModalProcessor( + BaseMultiModalProcessor[Idefics3ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Text-only input not supported in composite processor + if not (images := mm_data.get("images", [])): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + ) + 
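+        # For each input image, recompute the placeholder expansion (the
+        # per-tile tokens plus the global-image tokens), then mark which of
+        # those token positions hold the image placeholder token; that mask
+        # becomes ``embed_is_patch`` and the per-image ``num_patches`` below.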
+ parsed_images = (self._get_data_parser().parse_mm_data({ + "image": images + }).get_items("image", ImageProcessorItems)) + image_sizes = [ + parsed_images.get_image_size(i) for i in range(len(parsed_images)) + ] + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + image_repl_features = [ + self.info.get_image_repl(image_width=size.width, + image_height=size.height, + processor=hf_processor) + for size in image_sizes + ] + + tokenizer = self.info.get_tokenizer() + image_repls_feature_tokens = [ + tokenizer.encode(image_repl, add_special_tokens=False) + for image_repl in image_repl_features + ] + + vocab = tokenizer.get_vocab() + image_token_id = vocab[hf_processor.image_token.content] + + embed_is_patch = [ + torch.tensor(image_repl_tokens) == image_token_id + for image_repl_tokens in image_repls_feature_tokens + ] + processed_outputs["embed_is_patch"] = embed_is_patch + + num_patches = [ + self.info.get_num_patches( + image_width=size.width, + image_height=size.height, + processor=hf_processor, + ) for size in image_sizes + ] + processed_outputs["num_patches"] = torch.tensor(num_patches) + + # Remove the extra batch dimension + processed_outputs["pixel_values"].squeeze_(0) + processed_outputs["pixel_attention_mask"].squeeze_(0) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + image_embeds=MultiModalFieldConfig.batched("image"), + num_patches=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.image_token.content + + def get_replacement_idefics3(item_idx: int) -> str: + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + + return self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_idefics3, + ) + ] + + +class Idefics3SimpleMLP(nn.Module): + + def __init__( + self, + config: Idefics3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor** + 2) + output_size = config.text_config.hidden_size + self.proj = ReplicatedLinear( + input_size, + output_size, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "proj"), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out, _ = self.proj(x) + return out + + +class Idefics3Connector(nn.Module): + + def __init__( + self, + config: Idefics3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.scale_factor = config.scale_factor + self.modality_projection = Idefics3SimpleMLP( + config, + quant_config, + prefix=maybe_prefix(prefix, "modality_projection"), + ) + + def pixel_shuffle(self, + x: torch.Tensor, 
+ scale_factor: int = 2) -> torch.Tensor: + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), + embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), + embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: + image_hidden_states = self.pixel_shuffle(image_hidden_states, + self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Idefics3Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = self.config.text_config.vocab_size + self.vision_model = Idefics3VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model")) + self.connector = Idefics3Connector( + config, + quant_config, + prefix=maybe_prefix(prefix, "connector"), + ) + self.text_model = LlamaModel( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "text_model"), + ) + + self.image_seq_len = int( + ((config.vision_config.image_size // + config.vision_config.patch_size)**2) / (config.scale_factor**2)) + self.image_token_id = self.config.image_token_id + + def image_pixels_to_features( + self, + pixel_values: torch.Tensor, + pixel_attention_mask: torch.Tensor, + ) -> torch.Tensor: + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + pixel_values = pixel_values.to( + dtype=self.vision_model.embeddings.patch_embedding.weight.dtype + ) # fp16 compatibility + + # Remove padding images - padding images are full 0. 
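+        # An image slot is treated as padding when every element in it is
+        # exactly zero, so comparing each image's zero count against its
+        # total element count yields a boolean index of the real images.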
+ nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask[ + real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, + size=patch_size, + step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, + size=patch_size, + step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + return image_hidden_states + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.text_model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + + hidden_states = self.text_model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + Idefics3MultiModalProcessor, + info=Idefics3ProcessingInfo, + dummy_inputs=Idefics3DummyInputsBuilder) +class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.model = Idefics3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.image_token_id = self.config.image_token_id + + self.lm_head = ParallelLMHead( + config.text_config.vocab_size, + config.text_config.hidden_size, + quant_config=quant_config, + ) + if self.config.text_config.tie_word_embeddings: + self.lm_head.weight = self.model.text_model.wte.weight + self.logits_processor = LogitsProcessor(config.text_config.vocab_size) + self.sampler = get_sampler() + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. 
" + f"Got type: {type(embed_is_patch)}") + + embed_is_patch = flatten_bn(embed_is_patch) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return Idefics3ImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + embed_is_patch=embed_is_patch, + ) + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_attention_mask = kwargs.pop("pixel_attention_mask") + if not isinstance(pixel_attention_mask, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel_attention_mask. " + f"Got type: {type(pixel_attention_mask)}") + + num_patches = kwargs.pop("num_patches") + if not isinstance(num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_patches. " + f"Got type: {type(num_patches)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + pixel_attention_mask = flatten_bn(pixel_attention_mask, + concat=True) + num_patches = flatten_bn(num_patches, concat=True) + + return Idefics3ImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values(pixel_values), + pixel_attention_mask=pixel_attention_mask, + num_patches=num_patches, + embed_is_patch=embed_is_patch, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_pixels( + self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + pixel_values = inputs["pixel_values"] + pixel_attention_mask = inputs["pixel_attention_mask"] + + return self.model.image_pixels_to_features( + pixel_values, + pixel_attention_mask=pixel_attention_mask, + ) + + def _process_image_input( + self, + image_input: ImageInputs, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + image_features = self._process_image_pixels(image_input) + image_features = self.model.connector(image_features) + + num_patches = image_input["num_patches"] + return [ + e.flatten(0, 1) for e in image_features.split(num_patches.tolist()) + ] + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.config.image_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model.text_model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model.text_model", + connector="model.connector", + tower_model="model.vision_model") diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py new file mode 100644 index 0000000..c61254a --- /dev/null +++ b/vllm/model_executor/models/interfaces.py @@ -0,0 +1,557 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, + Protocol, Type, Union, overload, runtime_checkable) + +import torch +from torch import Tensor +from typing_extensions import Self, TypeIs + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.utils import supports_kw + +from .interfaces_base import is_pooling_model + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + +MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]] +""" +The output embeddings must be one of the following formats: + +- A list or tuple of 2D tensors, where each tensor corresponds to + each input multimodal data item (e.g, image). +- A single 3D tensor, with the batch dimension grouping the 2D tensors. +""" + + +@runtime_checkable +class SupportsMultiModal(Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + """ + Returns multimodal embeddings generated from multimodal kwargs + to be merged with text embeddings. + + Note: + The returned multimodal embeddings must be in the same order as + the appearances of their corresponding multimodal data item in the + input prompt. + """ + ... + + # Only for models that support v0 chunked prefill + # TODO(ywang96): Remove this overload once v0 is deprecated + @overload + def get_input_embeddings( + self, + input_ids: Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> Tensor: + ... 
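For reference, the `MultiModalEmbeddings` alias defined above accepts either per-item 2D tensors or one batched 3D tensor. A tiny illustration with made-up sizes (hidden size 8; two images yielding 4 and 6 patch embeddings, or 2 items of 5 each in the batched form):

```python
import torch

# list/tuple of 2D tensors, one per multimodal data item
per_item: list[torch.Tensor] = [torch.zeros(4, 8), torch.zeros(6, 8)]

# a single 3D tensor; the leading batch dim groups equally sized items
batched = torch.zeros(2, 5, 8)
```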
+ + @overload + def get_input_embeddings( + self, + input_ids: Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> Tensor: + """ + Returns the input embeddings merged from the text embeddings from + input_ids and the multimodal embeddings generated from multimodal + kwargs. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsMultiModalType(Protocol): + supports_multimodal: Literal[True] + + +@overload +def supports_multimodal( + model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]: + ... + + +@overload +def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: + ... + + +def supports_multimodal( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: + if isinstance(model, type): + return isinstance(model, _SupportsMultiModalType) + + return isinstance(model, SupportsMultiModal) + + +@runtime_checkable +class SupportsLoRA(Protocol): + """The interface required for all models that support LoRA.""" + + supports_lora: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports LoRA. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + # The `embedding_module` and `embedding_padding_modules` + # are empty by default. + embedding_modules: ClassVar[Dict[str, str]] = {} + embedding_padding_modules: ClassVar[List[str]] = [] + packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsLoRAType(Protocol): + supports_lora: Literal[True] + + packed_modules_mapping: Dict[str, List[str]] + embedding_modules: Dict[str, str] + embedding_padding_modules: List[str] + + +@overload +def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]: + ... + + +@overload +def supports_lora(model: object) -> TypeIs[SupportsLoRA]: + ... + + +def supports_lora( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: + result = _supports_lora(model) + + if not result: + lora_attrs = ( + "packed_modules_mapping", + "embedding_modules", + "embedding_padding_modules", + ) + missing_attrs = tuple(attr for attr in lora_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_lora", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_lora=True`, " + "but is missing LoRA-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all LoRA-specific attributes, " + "but does not set `supports_lora=True`.", model) + + return result + + +def _supports_lora(model: Union[Type[object], object]) -> bool: + if isinstance(model, type): + return isinstance(model, _SupportsLoRAType) + + return isinstance(model, SupportsLoRA) + + +@runtime_checkable +class SupportsPP(Protocol): + """The interface required for all models that support pipeline parallel.""" + + supports_pp: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pipeline parallel. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. 
+ """ + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + """Called when PP rank > 0 for profiling purposes.""" + ... + + def forward( + self, + *, + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[Tensor, "IntermediateTensors"]: + """ + Accept :class:`IntermediateTensors` when PP rank > 0. + + Return :class:`IntermediateTensors` only for the last PP rank. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsPPType(Protocol): + supports_pp: Literal[True] + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + ... + + def forward( + self, + *, + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[Tensor, "IntermediateTensors"]: + ... + + +@overload +def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: + ... + + +@overload +def supports_pp(model: object) -> TypeIs[SupportsPP]: + ... + + +def supports_pp( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + supports_attributes = _supports_pp_attributes(model) + supports_inspect = _supports_pp_inspect(model) + + if supports_attributes and not supports_inspect: + logger.warning( + "The model (%s) sets `supports_pp=True`, but does not accept " + "`intermediate_tensors` in its `forward` method", model) + + if not supports_attributes: + pp_attrs = ("make_empty_intermediate_tensors", ) + missing_attrs = tuple(attr for attr in pp_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_pp", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_pp=True`, " + "but is missing PP-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all PP-specific attributes, " + "but does not set `supports_pp=True`.", model) + + return supports_attributes and supports_inspect + + +def _supports_pp_attributes(model: Union[Type[object], object]) -> bool: + if isinstance(model, type): + return isinstance(model, _SupportsPPType) + + return isinstance(model, SupportsPP) + + +def _supports_pp_inspect(model: Union[Type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + return supports_kw(model_forward, "intermediate_tensors") + + +@runtime_checkable +class HasInnerState(Protocol): + """The interface required for all models that has inner state.""" + + has_inner_state: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has inner state. + Models that has inner state usually need access to the scheduler_config + for max_num_seqs, etc. True for e.g. both Mamba and Jamba. + """ + + +@runtime_checkable +class _HasInnerStateType(Protocol): + has_inner_state: ClassVar[Literal[True]] + + +@overload +def has_inner_state(model: object) -> TypeIs[HasInnerState]: + ... + + +@overload +def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]: + ... 
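These flag-plus-helper pairs follow one pattern throughout the file: a `ClassVar[Literal[True]]` marker plus an `isinstance` check against a `runtime_checkable` protocol, usable on both classes and instances. A small usage sketch with toy classes (the helper itself is implemented just below; this assumes the package added by this diff is importable):

```python
from typing import ClassVar, Literal

from vllm.model_executor.models.interfaces import has_inner_state

class MambaLike:  # toy stand-in, not a real vLLM model
    has_inner_state: ClassVar[Literal[True]] = True

class Plain:  # toy stand-in without the flag
    pass

assert has_inner_state(MambaLike) and has_inner_state(MambaLike())
assert not has_inner_state(Plain)
```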
+ + +def has_inner_state( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]: + if isinstance(model, type): + return isinstance(model, _HasInnerStateType) + + return isinstance(model, HasInnerState) + + +@runtime_checkable +class IsAttentionFree(Protocol): + """The interface required for all models like Mamba that lack attention, + but do have state whose size is constant wrt the number of tokens.""" + + is_attention_free: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has no attention. + Used for block manager and attention backend selection. + True for Mamba but not Jamba. + """ + + +@runtime_checkable +class _IsAttentionFreeType(Protocol): + is_attention_free: ClassVar[Literal[True]] + + +@overload +def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: + ... + + +@overload +def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]: + ... + + +def is_attention_free( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]: + if isinstance(model, type): + return isinstance(model, _IsAttentionFreeType) + + return isinstance(model, IsAttentionFree) + + +@runtime_checkable +class IsHybrid(Protocol): + """The interface required for all models like Jamba that have both + attention and mamba blocks, indicates that + hf_config has 'layers_block_type'""" + + is_hybrid: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has both mamba and attention blocks + , also indicates that the model's hf_config has + 'layers_block_type' """ + + +@runtime_checkable +class _IsHybridType(Protocol): + is_hybrid: ClassVar[Literal[True]] + + +@overload +def is_hybrid(model: object) -> TypeIs[IsHybrid]: + ... + + +@overload +def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]: + ... + + +def is_hybrid( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]: + if isinstance(model, type): + return isinstance(model, _IsHybridType) + + return isinstance(model, IsHybrid) + + +@runtime_checkable +class HasNoOps(Protocol): + has_noops: ClassVar[Literal[True]] = True + + +@runtime_checkable +class _HasNoOpsType(Protocol): + has_noops: ClassVar[Literal[True]] + + +@overload +def has_noops(model: object) -> TypeIs[HasNoOps]: + ... + + +@overload +def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]: + ... + + +def has_noops( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]: + if isinstance(model, type): + return isinstance(model, _HasNoOpsType) + + return isinstance(model, HasNoOps) + + +@runtime_checkable +class SupportsCrossEncoding(Protocol): + """The interface required for all models that support cross encoding.""" + + supports_cross_encoding: ClassVar[Literal[True]] = True + + +@overload +def supports_cross_encoding( + model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]: + ... + + +@overload +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: + ... 
+ + +def _supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + + if isinstance(model, type): + return isinstance(model, SupportsCrossEncoding) + + return isinstance(model, SupportsCrossEncoding) + + +def supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + return is_pooling_model(model) and _supports_cross_encoding(model) + + +class SupportsQuant: + """The interface required for all models that support quantization.""" + + packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} + quant_config: Optional[QuantizationConfig] = None + + def __new__(cls, *args, **kwargs) -> Self: + instance = super().__new__(cls) + quant_config = cls._find_quant_config(*args, **kwargs) + if quant_config is not None: + instance.quant_config = quant_config + instance.quant_config.packed_modules_mapping.update( + cls.packed_modules_mapping) + return instance + + @staticmethod + def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: + from vllm.config import VllmConfig # avoid circular import + + args_values = list(args) + list(kwargs.values()) + for arg in args_values: + if isinstance(arg, VllmConfig): + return arg.quant_config + + if isinstance(arg, QuantizationConfig): + return arg + + return None + + +@runtime_checkable +class SupportsTranscription(Protocol): + """The interface required for all models that support transcription.""" + + supports_transcription: ClassVar[Literal[True]] = True + + +@overload +def supports_transcription( + model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: + ... + + +@overload +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: + ... + + +def supports_transcription( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + if isinstance(model, type): + return isinstance(model, SupportsTranscription) + + return isinstance(model, SupportsTranscription) + + +@runtime_checkable +class SupportsV0Only(Protocol): + """Models with this interface are not compatible with V1 vLLM.""" + + supports_v0_only: ClassVar[Literal[True]] = True + + +@overload +def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]: + ... + + +@overload +def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: + ... 
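`SupportsQuant` above hooks `__new__` so that any constructor argument carrying a quantization config is picked up automatically, without each model having to wire it through. A simplified re-implementation of that discovery step (the `DummyQuantConfig`/`ToyModel` names are invented for illustration; this is not the vLLM class itself):

```python
from typing import Optional

class DummyQuantConfig:  # stand-in for vLLM's QuantizationConfig
    pass

class SupportsQuantSketch:
    quant_config: Optional[DummyQuantConfig] = None

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        # scan positional and keyword constructor arguments for a config
        for arg in (*args, *kwargs.values()):
            if isinstance(arg, DummyQuantConfig):
                instance.quant_config = arg
                break
        return instance

class ToyModel(SupportsQuantSketch):
    def __init__(self, quant_config=None):
        pass

assert ToyModel(quant_config=DummyQuantConfig()).quant_config is not None
assert ToyModel().quant_config is None
```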
+ + +def supports_v0_only( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]: + if isinstance(model, type): + return isinstance(model, SupportsV0Only) + + return isinstance(model, SupportsV0Only) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py new file mode 100644 index 0000000..22c9287 --- /dev/null +++ b/vllm/model_executor/models/interfaces_base.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload, + runtime_checkable) + +import torch +import torch.nn as nn +from typing_extensions import TypeIs, TypeVar + +from vllm.logger import init_logger +from vllm.utils import supports_kw + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import PoolerOutput + from vllm.model_executor.layers.sampler import SamplerOutput + from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.sampling_metadata import SamplingMetadata + +logger = init_logger(__name__) + +# The type of hidden states +# Currently, T = torch.Tensor for all models except for Medusa +# which has T = List[torch.Tensor] +T = TypeVar("T", default=torch.Tensor) +T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) + +# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags +# for the base interfaces to avoid breaking OOT registration for existing models +# that don't inherit from the base interface classes + + +@runtime_checkable +class VllmModel(Protocol[T_co]): + """The interface required for all models in vLLM.""" + + def __init__( + self, + vllm_config: "VllmConfig", + prefix: str = "", + ) -> None: + ... + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ) -> T_co: + ... + + +def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: + model_init = model.__init__ + return supports_kw(model_init, "vllm_config") + + +def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + vllm_kws = ("input_ids", "positions") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_forward, kw)) + + if missing_kws and (isinstance(model, type) + and issubclass(model, nn.Module)): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its `forward` method: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +@overload +def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]: + ... + + +@overload +def is_vllm_model(model: object) -> TypeIs[VllmModel]: + ... + + +def is_vllm_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]: + return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + + +@runtime_checkable +class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): + """The interface required for all generative models in vLLM.""" + + def compute_logits( + self, + hidden_states: T, + sampling_metadata: "SamplingMetadata", + ) -> Optional[T]: + """Return `None` if TP rank > 0.""" + ... + + def sample( + self, + logits: T, + sampling_metadata: "SamplingMetadata", + ) -> "SamplerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def is_text_generation_model( + model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]: + ... 
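`_check_vllm_model_forward` above relies on `supports_kw` to verify that a candidate model's `forward` accepts the `input_ids` and `positions` keywords. A rough sketch of that kind of check using only `inspect` (an approximation of the idea, not the actual `vllm.utils.supports_kw` implementation; `ToyModel` is hypothetical):

```python
import inspect

def accepts_kw(func, kw: str) -> bool:
    """Very rough check: does `func` accept `kw` as a keyword argument?"""
    params = inspect.signature(func).parameters
    if kw in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

class ToyModel:
    def forward(self, input_ids, positions):
        return input_ids

assert accepts_kw(ToyModel.forward, "input_ids")
assert accepts_kw(ToyModel.forward, "positions")
assert not accepts_kw(ToyModel.forward, "intermediate_tensors")
```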
+ + +@overload +def is_text_generation_model( + model: object) -> TypeIs[VllmModelForTextGeneration]: + ... + + +def is_text_generation_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForTextGeneration]], + TypeIs[VllmModelForTextGeneration]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForTextGeneration) + + return isinstance(model, VllmModelForTextGeneration) + + +@runtime_checkable +class VllmModelForPooling(VllmModel[T], Protocol[T]): + """The interface required for all pooling models in vLLM.""" + + def pooler( + self, + hidden_states: T, + pooling_metadata: "PoolingMetadata", + ) -> "PoolerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]: + ... + + +@overload +def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: + ... + + +def is_pooling_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForPooling) + + return isinstance(model, VllmModelForPooling) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py new file mode 100644 index 0000000..0499f33 --- /dev/null +++ b/vllm/model_executor/models/intern_vit.py @@ -0,0 +1,476 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from functools import partial +from typing import Iterable, Optional, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +NORM2FN = { + 'rms_norm': RMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, 
self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False) + return pos_embed.reshape(1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + + def _get_position_embedding(self, H: int, W: int) -> torch.Tensor: + position_embedding = self.position_embedding + if self.num_patches == H * W: + return position_embedding + + return torch.cat( + [ + position_embedding[:, :1, :], + self._get_pos_embed(position_embedding[:, 1:, :], H, W), + ], + dim=1, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + target_dtype)) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = self._get_position_embedding(height, width) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternVisionPatchModel(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embeddings = InternVisionEmbeddings(config) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if pixel_values is None and pixel_embeds is None: + raise ValueError( + 'You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + elif pixel_values is not None: + if pixel_values.ndim == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError( + f'wrong pixel_values size: {pixel_values.shape}') + + return hidden_states + + +class InternParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + # Additional dummy heads are used to enable TP for common GPU counts. 
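# --- Reviewer note (illustrative, not part of the diff): a worked example of the
# dummy-head padding mentioned above, assuming an InternViT-6B-style tower with
# 25 attention heads and head_dim = 128 (these numbers are assumptions, not taken
# from this file):
#
#     num_heads = 25, num_dummy_heads = 7   ->  25 + 7 = 32 padded heads
#     dummy_dim = 32 * 128 = 4096
#     num_heads_per_partition = 32 // tp_size   # 16 / 8 / 4 for tp_size = 2 / 4 / 8
#
# The extra heads exist only so that the per-rank split comes out even.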
+ self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim + self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads, + self.tp_size) + + self.scale = self.head_dim**-0.5 + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + num_dummy_heads + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + self.k_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + + self.proj = RowParallelLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, _ = x.shape + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + if self.qk_normalization: + q, k = self._apply_qk_norm(q, k) + + out = self.attn(q, k, v) + out, _ = self.proj(out) + return out + + +class InternSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: PretrainedConfig, + *, + num_dummy_heads: int = 0, + ) -> None: + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + # Additional dummy heads are used to enable TP for common GPU counts. 
+ self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim + + self.scale = self.head_dim**-0.5 + self.qkv = nn.Linear(self.embed_dim, + 3 * self.dummy_dim, + bias=config.qkv_bias) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + self.k_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + + self.proj = nn.Linear(self.dummy_dim, self.embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(B, N, self.num_heads, self.head_dim) + k = k.view(B, N, self.num_heads, self.head_dim) + v = v.view(B, N, self.num_heads, self.head_dim) + + if self.qk_normalization: + B_, N_, H_, D_ = q.shape + q = self.q_norm.forward_native(q.flatten(-2, + -1)).view(B_, N_, H_, D_) + k = self.k_norm.forward_native(k.flatten(-2, + -1)).view(B_, N_, H_, D_) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + x = x.transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + return x + + +class InternMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = self._init_attn(config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.attn") + + self.mlp = InternMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + + def _init_attn( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + num_dummy_heads: int, + prefix: str = "", + ): + # fallback to sdpa attention if tp unavailable + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + + if (num_heads + num_dummy_heads) % tp_size == 0: + return InternParallelAttention(config, + quant_config=quant_config, + num_dummy_heads=num_dummy_heads, + prefix=prefix) + + return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) + + def forward( + self, + hidden_states: 
torch.Tensor, + ): + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states)) * self.ls1 + + hidden_states = hidden_states + self.mlp( + self.norm2(hidden_states)) * self.ls2 + + return hidden_states + + +class InternVisionEncoder(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ): + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class InternVisionModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.encoder", + ) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if pixel_values is None and pixel_embeds is None: + raise ValueError( + 'You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + elif pixel_values is not None: + if pixel_values.ndim == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError( + f'wrong pixel_values size: {pixel_values.shape}') + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + return encoder_outputs + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py new file mode 100644 index 0000000..520b85c --- /dev/null +++ b/vllm/model_executor/models/internlm2.py @@ -0,0 +1,463 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import partial +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.activation import 
SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class InternLM2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.w2", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class InternLM2Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= self.tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % self.tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
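# --- Reviewer note (illustrative, not part of the diff): two hypothetical
# configurations for the branch above (numbers are examples only):
#
#     total_num_kv_heads = 8, tp_size = 4  -> partition: 8 % 4 == 0,
#                                             num_kv_heads = 8 // 4 = 2 per rank
#     total_num_kv_heads = 2, tp_size = 8  -> replicate: 8 % 2 == 0,
#                                             num_kv_heads = max(1, 2 // 8) = 1 per rank,
#                                             so each KV head is held by 8 / 2 = 4 ranks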
+ assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.key_value_groups = int(self.num_heads / self.num_kv_heads) + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.wqkv = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wqkv", + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wo", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def split_qkv(self, qkv: torch.Tensor): + seq_len = qkv.shape[0] + if self.tp_size > 1: + qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size + qkv = tensor_model_parallel_all_gather(qkv) + qkv = torch.split(qkv, qkv_map, dim=-1) + qkv = qkv[::3] + qkv[1::3] + qkv[2::3] + qkv = torch.cat(qkv, dim=-1) + + qkv = qkv.view(seq_len, self.total_num_kv_heads, + self.key_value_groups + 2, self.head_dim) + q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) + q = q.reshape(seq_len, self.q_size * self.tp_size) + k = k.reshape(seq_len, self.kv_size * self.tp_size) + v = v.reshape(seq_len, self.kv_size * self.tp_size) + + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + return q, k, v + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.wqkv(hidden_states) + q, k, v = self.split_qkv(qkv) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.wo(attn_output) + return output + + +class InternLMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.attention = InternLM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + self.feed_forward = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.attention_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def 
forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class InternLM2Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.tok_embeddings = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "wqkv": ["wqkv"], + "gate_up_proj": ["w1", "w3"], + } + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: Type[InternLM2Model] = InternLM2Model): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + self.lora_config = lora_config + + self.model = model_type(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.output = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "output")) + if self.config.tie_word_embeddings: + self.output.weight = self.model.tok_embeddings.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + 
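# --- Reviewer note (illustrative, not part of the diff): the class-level
# `packed_modules_mapping` above declares that checkpoint tensors named `w1` and
# `w3` live inside the merged `gate_up_proj` module. Together with the
# `stacked_params_mapping` used by `load_weights` further below, a hypothetical
# checkpoint entry
#     "model.layers.0.feed_forward.w1.weight"
# is renamed to
#     "model.layers.0.feed_forward.gate_up_proj.weight"
# and written into shard 0 of the merged parameter, while `w3` goes to shard 1.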
self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.output, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class InternLM2ForRewardModel(InternLM2ForCausalLM): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: Type[InternLM2Model] = InternLM2Model, + ): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + model_type=model_type) + + for attr in ("output", "logits_processor", "sampler"): + delattr(self, attr) + + config = vllm_config.model_config.hf_config + self.v_head = RowParallelLinear( + config.hidden_size, + 1, + bias=False, + input_is_parallel=False, + prefix=maybe_prefix(prefix, "v_head"), + ) + + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.ALL, + normalize=False, + softmax=False, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + logits, _ = self.v_head(hidden_states) + return logits + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py new file mode 100644 index 0000000..69b0caa --- /dev/null +++ b/vllm/model_executor/models/internlm2_ve.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.internlm2 import (InternLM2Attention, + InternLM2ForCausalLM, + InternLM2MLP, InternLM2Model) +from vllm.sequence import IntermediateTensors + + +class InternLM2VEDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.attention = InternLM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + self.feed_forward = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.feed_forward_ve = InternLM2MLP( + hidden_size=self.hidden_size, + 
intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward_ve", + ) + self.attention_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + if visual_token_mask is not None and visual_token_mask.any(): + visual_token_mask = visual_token_mask.repeat( + 1, self.hidden_size).bool() + text_token_mask = ~visual_token_mask + hidden_states[visual_token_mask] = self.feed_forward_ve( + hidden_states[visual_token_mask].reshape( + -1, self.hidden_size)).flatten() + if text_token_mask.any(): + hidden_states[text_token_mask] = self.feed_forward( + hidden_states[text_token_mask].reshape( + -1, self.hidden_size)).flatten() + else: + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class InternLM2VEModel(InternLM2Model): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=InternLM2VEDecoderLayer) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + visual_token_mask: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.tok_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + visual_token_mask=visual_token_mask, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class InternLM2VEForCausalLM(InternLM2ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + model_type=InternLM2VEModel) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py new file mode 100644 index 0000000..0729f4c --- /dev/null +++ b/vllm/model_executor/models/internvl.py @@ -0,0 +1,1011 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from abc import ABC, abstractmethod +from collections.abc import 
Iterable, Mapping, Sequence
+from functools import cached_property
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import BatchEncoding, PretrainedConfig, TensorType
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.models.intern_vit import (InternVisionModel,
+                                                   InternVisionPatchModel)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
+
+IMG_START = '<img>'
+IMG_END = '</img>'
+IMG_CONTEXT = '<IMG_CONTEXT>'
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class InternVLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    pixel_values_flat: torch.Tensor
+    """
+    Shape:
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+    """
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
+
+class InternVLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
+    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+ """ + + +InternVLImageInputs = Union[InternVLImagePixelInputs, + InternVLImageEmbeddingInputs] + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def build_transform(input_size: int): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + return T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def find_closest_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + *, + width: int, + height: int, + image_size: int, +) -> tuple[int, int]: + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def resolve_internvl_min_max_num( + *, + min_dynamic_patch: int, + max_dynamic_patch: int, + dynamic_image_size: bool, + use_thumbnail: bool, +) -> tuple[int, int]: + min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 + + if use_thumbnail and max_dynamic_patch != 1: + max_dynamic_patch += 1 + + return min_dynamic_patch, max_dynamic_patch + + +def get_internvl_target_ratios( + min_num: int, + max_num: int, +) -> list[tuple[int, int]]: + target_ratios = {(i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) if min_num <= i * j <= max_num} + return sorted(target_ratios, key=lambda x: x[0] * x[1]) + + +def calculate_internvl_targets( + *, + orig_width: int, + orig_height: int, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[int, int, int]: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # add thumbnail image if num_blocks != 1 + if use_thumbnail and blocks != 1: + blocks += 1 + + return blocks, target_width, target_height + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def dynamic_preprocess_internvl( + image: Image.Image, + *, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> list[Image.Image]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + blocks, target_width, target_height = calculate_internvl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False, + ) + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ((i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) 
* image_size) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def image_to_pixel_values_internvl( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, +) -> torch.Tensor: + target_ratios = get_internvl_target_ratios(min_num, max_num) + + transform = build_transform(input_size=input_size) + images = dynamic_preprocess_internvl( + image, + target_ratios=target_ratios, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) + + pixel_values = torch.stack([transform(image) for image in images]) + return pixel_values + + +class BaseInternVLProcessor(ABC): + """ + This model doesn't define its own HF processor, + so we implement our own one here. + + The code to insert image tokens is based on: + https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252 + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + image_size: int = config.vision_config.image_size + patch_size: int = config.vision_config.patch_size + + if min_dynamic_patch is None: + min_dynamic_patch = config.min_dynamic_patch + assert isinstance(min_dynamic_patch, int) + + if max_dynamic_patch is None: + max_dynamic_patch = config.max_dynamic_patch + assert isinstance(max_dynamic_patch, int) + + if dynamic_image_size is None: + dynamic_image_size = config.dynamic_image_size + assert isinstance(dynamic_image_size, bool) + + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.image_size = image_size + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail: bool = config.use_thumbnail + + @property + @abstractmethod + def image_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + raise NotImplementedError + + def resolve_min_max_num( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> tuple[int, int]: + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) + max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch + is None else max_dynamic_patch) + dynamic_image_size = (self.dynamic_image_size if dynamic_image_size + is None else dynamic_image_size) + use_thumbnail = (self.use_thumbnail + if use_thumbnail is None else use_thumbnail) + + return resolve_internvl_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + def resolve_target_ratios( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + 
use_thumbnail: Optional[bool] = None,
+    ) -> list[tuple[int, int]]:
+        min_num, max_num = self.resolve_min_max_num(
+            min_dynamic_patch=min_dynamic_patch,
+            max_dynamic_patch=max_dynamic_patch,
+            dynamic_image_size=dynamic_image_size,
+            use_thumbnail=use_thumbnail,
+        )
+
+        return get_internvl_target_ratios(min_num, max_num)
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        target_ratios = self.resolve_target_ratios(
+            use_thumbnail=False,  # Applied in calculate_targets
+        )
+
+        num_patches, _, _ = calculate_internvl_targets(
+            orig_width=image_width,
+            orig_height=image_height,
+            image_size=self.image_size,
+            target_ratios=target_ratios,
+            use_thumbnail=self.use_thumbnail,
+        )
+
+        return num_patches * self.num_image_token
+
+    def _images_to_pixel_values_lst(
+        self,
+        images: list[Image.Image],
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+    ) -> list[torch.Tensor]:
+        min_num, max_num = self.resolve_min_max_num(
+            min_dynamic_patch=min_dynamic_patch,
+            max_dynamic_patch=max_dynamic_patch,
+            dynamic_image_size=dynamic_image_size,
+            use_thumbnail=False,  # Applied in image_to_pixel_values
+        )
+
+        return [
+            image_to_pixel_values_internvl(
+                image,
+                input_size=self.image_size,
+                min_num=min_num,
+                max_num=max_num,
+                use_thumbnail=self.use_thumbnail,
+            ) for image in images
+        ]
+
+    def __call__(
+        self,
+        text: Optional[Union[str, list[str]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ) -> Mapping[str, NestedTensors]:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+
+        if len(images) == 0:
+            image_inputs = {}
+        else:
+            pixel_values_lst = self._images_to_pixel_values_lst(
+                images,
+                min_dynamic_patch=min_dynamic_patch,
+                max_dynamic_patch=max_dynamic_patch,
+                dynamic_image_size=dynamic_image_size,
+            )
+            image_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat":
+                torch.cat(pixel_values_lst),
+                "image_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst]),
+            }
+
+            tokenizer = self.tokenizer
+            image_token_id = self.image_token_id
+
+            embed_is_patch = list[torch.Tensor]()
+
+            for pixel_values in pixel_values_lst:
+                num_patches = pixel_values.shape[0]
+                feature_size = num_patches * self.num_image_token
+
+                image_repl = self.get_image_repl(feature_size, num_patches)
+                feature_tokens = tokenizer.encode(image_repl.features,
+                                                  add_special_tokens=False)
+
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+                embed_is_patch.append(
+                    torch.tensor(feature_tokens) == image_token_id)
+
+            image_inputs["embed_is_patch"] = embed_is_patch
+
+        text_inputs = self.tokenizer(text)
+
+        return {
+            **BatchEncoding(text_inputs, tensor_type=return_tensors),
+            **image_inputs,
+        }
+
+
+class InternVLProcessor(BaseInternVLProcessor):
+
+    @property
+    def image_token_id(self) -> int:
+        return self.tokenizer.get_vocab()[IMG_CONTEXT]
+
+    def get_image_repl(
+        self,
+        feature_size: int,
+        num_patches: Optional[int],
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END
+
+        return PromptUpdateDetails(full=repl_full, features=repl_features)
+
+
+class BaseInternVLProcessingInfo(BaseProcessingInfo):
+
+    @abstractmethod
+    def get_hf_processor(
+        self,
+        *,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        **kwargs: object,
+    ) -> BaseInternVLProcessor:
+        raise NotImplementedError
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_max_image_tokens()}
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[BaseInternVLProcessor],
+    ) -> int:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        return processor.get_num_image_tokens(
+            image_width=image_width,
+            image_height=image_height,
+        )
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            processor=None,
+        )
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        processor = self.get_hf_processor()
+
+        base_size = processor.image_size
+        target_ratios = processor.resolve_target_ratios()
+
+        largest_feature_size, largest_feature_pinpoint = 0, None
+        for wr, hr in target_ratios:
+            width, height = base_size * wr, base_size * hr
+
+            feat_size = self.get_num_image_tokens(
+                image_width=width,
+                image_height=height,
+                processor=processor,
+            )
+            if feat_size > largest_feature_size:
+                largest_feature_size = feat_size
+                largest_feature_pinpoint = ImageSize(width=width,
+                                                     height=height)
+
+        if largest_feature_size == 0 or largest_feature_pinpoint is None:
+            raise ValueError("Cannot have a largest feature size of 0!")
+
+        return largest_feature_pinpoint
+
+
+_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
+
+
+class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="<image>" * num_images,
+            mm_data=mm_data,
+        )
+
+
+class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, NestedTensors]:
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        image_token_id = hf_processor.image_token_id
+
+        # Since there may be extra tokens in the feature placeholders,
+        # we need to pass the image token ID to the model to select the
+        # tokens to merge from the vision encoder outputs
+        processed_outputs["image_token_id"] = torch.tensor(image_token_id)
+
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: Mapping[str, NestedTensors],
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
+        num_images = len(image_num_patches)
+
+        return dict(
+            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_num_patches),
+            image_num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+            image_token_id=MultiModalFieldConfig.shared("image", num_images),
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+        if "image_num_patches" in out_mm_kwargs:
+            image_num_patches = out_mm_kwargs["image_num_patches"]
+            assert isinstance(image_num_patches, torch.Tensor)
+            image_num_patches = image_num_patches.tolist()
+        elif "image_embeds" in out_mm_kwargs:
+            # TODO: Use image size information in dictionary embedding inputs
+            # to compute num_patches (similar to Qwen2-VL)
+            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+        else:
+            image_num_patches = []
+
+        def get_replacement_internvl(item_idx: int):
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+            if isinstance(images, ImageEmbeddingItems):
+                feature_size = images.get_feature_size(item_idx)
+            else:
+                image_size = images.get_image_size(item_idx)
+                feature_size = self.info.get_num_image_tokens(
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                    processor=hf_processor,
+                )
+
+            num_patches = image_num_patches[item_idx]
+            if num_patches is not None:
+                assert isinstance(num_patches, int)
+
+            return hf_processor.get_image_repl(feature_size, num_patches)
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement=get_replacement_internvl,
+            )
+        ]
+
+
+class InternVLProcessingInfo(BaseInternVLProcessingInfo):
+
+    def get_hf_processor(
+        self,
+        *,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        **kwargs: object,
+    ) -> InternVLProcessor:
+        if min_dynamic_patch is not None:
+            kwargs["min_dynamic_patch"] = min_dynamic_patch
+        if max_dynamic_patch is not None:
+            kwargs["max_dynamic_patch"] = max_dynamic_patch
+        if dynamic_image_size is not None:
+            kwargs["dynamic_image_size"] = dynamic_image_size
+
+        return self.ctx.init_processor(
+            InternVLProcessor,
+            config=self.get_hf_config(),
+            tokenizer=self.get_tokenizer(),
+            **kwargs,
+        )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    InternVLMultiModalProcessor,
+    info=InternVLProcessingInfo,
+    dummy_inputs=InternVLDummyInputsBuilder)
+class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+        self._patch_quant_config(config, quant_config)
+
+        image_size = config.force_image_size or config.vision_config.image_size
+        patch_size = config.vision_config.patch_size
+        self.patch_size = patch_size
+        self.num_image_token = int(
+            (image_size // patch_size)**2 * (config.downsample_ratio**2))
+        self.downsample_ratio = config.downsample_ratio
+        self.ps_version = config.ps_version
+
+        self.llm_arch_name = config.text_config.architectures[0]
+        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
+        self.vision_model = self._init_vision_model(
+            config,
+            quant_config=quant_config,
+            is_mono=self.is_mono,
+
prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.mlp1 = self._init_mlp1(config) + + self.img_context_token_id = None + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _patch_quant_config(self, config: PretrainedConfig, + quant_config: QuantizationConfig): + # the awq models from OpenGVLab missing `modules_to_not_convert` + # patch the quant_config to add `modules_to_not_convert` back + if isinstance(quant_config, AWQConfig): + text_config = config.text_config + llm_quant_config = getattr(text_config, "quantization_config", + None) + if (not quant_config.modules_to_not_convert) and \ + (llm_quant_config is not None): + quant_config.modules_to_not_convert.append("vision_model") + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + return InternVisionPatchModel(config.vision_config) + + def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.text_config.hidden_size + + return nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + llm_hidden_size), + nn.GELU(), + nn.Linear(llm_hidden_size, llm_hidden_size), + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + pass + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: + vit_embeds = self.vision_model(pixel_values=pixel_values) + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. 
" + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[InternVLImageInputs]: + pixel_values_flat = kwargs.pop("pixel_values_flat", None) + image_num_patches = kwargs.pop("image_num_patches", None) + embed_is_patch = kwargs.pop("embed_is_patch", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return InternVLImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + image_token_id = kwargs["image_token_id"] + assert isinstance(image_token_id, torch.Tensor) + self.img_context_token_id = image_token_id.flatten().unique().item() + + if pixel_values_flat is not None: + if not isinstance(pixel_values_flat, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}") + + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. " + f"Got type: {type(image_num_patches)}") + + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) + embed_is_patch = flatten_bn(embed_is_patch) + + return InternVLImagePixelInputs( + type="pixel_values", + pixel_values_flat=self._validate_pixel_values( + pixel_values_flat), + num_patches=image_num_patches, + embed_is_patch=embed_is_patch, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: InternVLImageInputs, + ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) + + num_patches = image_input["num_patches"] + + # Only one image in the current batch + if len(num_patches) == 1: + return image_embeds.view( + -1, self.config.text_config.hidden_size).unsqueeze(0) + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. 
+ feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in num_patches + ] + return image_embeds.split(image_feature_sizes) + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if self.is_mono: + self.visual_token_mask = ( + input_ids == self.img_context_token_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + if image_input["type"] != "pixel_values": + return image_features + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert self.img_context_token_id is not None + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.img_context_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[SamplerOutput, IntermediateTensors]: + + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
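+        # In this v0 path the multimodal kwargs still carry the raw image
+        # inputs, so they are encoded and merged into the token embeddings
+        # here before the language model runs.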
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + # Only required if the model is mono-architecture + if self.visual_token_mask is not None: + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model(**forward_kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B + skip_prefixes = [ + "action_embed", "temporal_embed", "track_embed", + "track_embed_decoder", "box_token", "cg_criterion", "cg_model", + "loc_encoder", "loc_decoder", "sam", "temporal_token", + "track_token" + ] + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py new file mode 100644 index 0000000..78fe658 --- /dev/null +++ b/vllm/model_executor/models/jais.py @@ -0,0 +1,381 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights +# reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
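+# This port keeps JAIS's ALiBi attention, muP-style embedding/logit scaling,
+# and the optional SwiGLU MLP; see JAISAttention, JAISModel and JAISMLP below.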
+"""Inference-only Jais model compatible with HuggingFace weights.""" + +import math +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import JAISConfig + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class SwiGLUActivation(nn.Module): + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return x1 * nn.functional.silu(x2) + + +def _get_alibi_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + +class JAISAttention(nn.Module): + + def __init__( + self, + config: JAISConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + if hasattr(config, "scale_qk_dot_by_d"): + config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d + self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 + self.scale = self.head_dim**-self.attn_scale_power + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end] + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = 
self.attn(q, k, v) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class JAISMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: JAISConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.swiglu = config.activation_function == "swiglu" + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.c_fc2 = (ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) if self.swiglu else None) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + self.act = SwiGLUActivation() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.swiglu: + hidden_states2, _ = self.c_fc2(hidden_states) + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = (self.act(hidden_states, hidden_states2) + if self.swiglu else self.act(hidden_states)) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class JAISBlock(nn.Module): + + def __init__( + self, + config: JAISConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = JAISAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = JAISMLP(inner_dim, config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn(hidden_states=hidden_states, ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +@support_torch_compile +class JAISModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = (nn.Embedding(config.max_position_embeddings, + self.embed_dim) + if config.position_embedding_type != "alibi" else None) + if hasattr(config, "embeddings_scale"): + self.embeddings_scale = config.embeddings_scale + else: + self.embeddings_scale = config.mup_embeddings_scale + + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: JAISBlock(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.h", + ) + + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def 
get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[IntermediateTensors, torch.Tensor]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + if self.wpe is not None: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = inputs_embeds + hidden_states *= torch.tensor(float(self.embeddings_scale), + dtype=hidden_states.dtype) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class JAISLMHeadModel(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.transformer = JAISModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + if hasattr(config, "width_scale"): + self.output_logits_scale = config.width_scale + else: + self.output_logits_scale = (config.mup_output_alpha * + config.mup_width_scale) + self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, + scale=self.output_logits_scale) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[IntermediateTensors, torch.Tensor]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. 
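+                # (".attn.bias" here is the GPT-2 style causal-mask buffer,
+                # while "c_attn.bias" is a real linear-layer bias that must
+                # still be loaded.)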
+ continue + if "relative_pe" in name: + continue + if not name.startswith("transformer."): + name = "transformer." + name + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py new file mode 100644 index 0000000..6fabc82 --- /dev/null +++ b/vllm/model_executor/models/jamba.py @@ -0,0 +1,600 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only Jamba model.""" +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import JambaConfig + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.utils import LayerBlockType + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsV0Only) +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class JambaMoE(nn.Module): + + def __init__(self, + config: JambaConfig, + num_experts: Optional[int] = None, + top_k: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.num_total_experts = num_experts or config.num_experts + self.top_k = top_k or config.num_experts_per_tok + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + if self.num_total_experts > 1: + self.router = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + params_dtype=params_dtype) + + self.experts = FusedMoE(self.num_total_experts, + self.top_k, + self.hidden_size, + self.intermediate_size, + tp_size=tp_size, + 
params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + use_grouped_topk=False, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + if self.num_total_experts > 1: + router_logits, _ = self.router(hidden_states) + else: + router_logits = torch.ones((hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=hidden_states.dtype) + hidden_states = self.experts(hidden_states, router_logits) + return hidden_states.view(orig_shape) + + +class JambaMLP(JambaMoE): + + def __init__(self, + config: JambaConfig, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(config, + num_experts=1, + top_k=1, + params_dtype=params_dtype, + tp_size=tp_size, + quant_config=quant_config, + prefix=prefix) + + +class JambaMambaDecoderLayer(nn.Module): + + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False, + prefix: str = "", + **kwargs) -> None: + super().__init__() + self.config = config + self.is_lora_enabled = is_lora_enabled + self.mamba = MambaMixer(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + time_step_rank = config.mamba_dt_rank, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + use_rms_norm=True, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + is_lora_enabled = self.is_lora_enabled + ) + + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, mamba_cache_params) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class JambaAttentionDecoderLayer(nn.Module): + + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + **kwargs) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # 
the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": JambaAttentionDecoderLayer, + "mamba": JambaMambaDecoderLayer +} + + +class JambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return layer_class(config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + **extra_kwargs) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, 
get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + kv_cache_index = 0 + mamba_cache_index = 0 + for layer in self.layers[self.start_layer:self.end_layer]: + layer_mamba_cache_params = None + if isinstance(layer, JambaAttentionDecoderLayer): + kv_cache_index += 1 + if isinstance(layer, JambaMambaDecoderLayer): + current_state_layer = mamba_cache_index + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + current_state_layer) + mamba_cache_index += 1 + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsV0Only): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "in_proj": ["in_proj"], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Jamba currently does not support prefix caching" + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.scheduler_config = scheduler_config + self.model = JambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Used to track and store by the Mamba cache between steps. 
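+        # The cache manager is created lazily on the first forward() call,
+        # once the number of Mamba layers and per-layer state shapes are known.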
+ self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_conv - 1, + ) + temporal_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." 
in name: + name = name.replace(".self_attn", "") + + if "feed_forward" in name and not _is_moe_layer(name): + ## map MLP layers to expert with ID=0 + name = name.replace("feed_forward", "feed_forward.experts.0") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if 'experts' in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for ( + param_name, + weight_name, + expert_id, + shard_id, + ) in expert_params_mapping: + if weight_name not in name: + continue + + if is_pp_missing_parameter(name, self): + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def _is_moe_layer(name: str): + return any( + [experts_name in name for experts_name in [ + "experts", + "router", + ]]) + + +class JambaForSequenceClassification(JambaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + num_labels: int = config.num_labels + score_bias: bool = getattr(config, 'score_bias', False) + self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias) + + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=False) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + hidden_states = hidden_states.float() + logits = self.score(hidden_states) + return self._pooler(logits, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # TODO: The reward weights themselves have float32 accuracy data, we + # would like to load them in fp32 to get that extra precision. + super().load_weights(weights) + self.score = self.score.float() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py new file mode 100644 index 0000000..4a5982e --- /dev/null +++ b/vllm/model_executor/models/llama.py @@ -0,0 +1,609 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__(self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "") -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + # Phi models introduced a partial_rotary_factor parameter in the config + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.rotary_dim = int(partial_rotary_factor * self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "llama": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + if hasattr(config, "interleaved_sliding_window"): + interleaved_sliding_window = config.interleaved_sliding_window + if isinstance(interleaved_sliding_window, int): + sliding_window = interleaved_sliding_window + elif isinstance(interleaved_sliding_window, list): + sw_idx = layer_idx % len(interleaved_sliding_window) + sliding_window = interleaved_sliding_window[sw_idx] + else: + raise ValueError( + f"{type(interleaved_sliding_window)} is not supported.") + else: + sliding_window = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + def forward( + self, 
+ positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + bias_o_proj = attention_bias + # support internlm/internlm3-8b with qkv_bias + if hasattr(config, 'qkv_bias'): + attention_bias = config.qkv_bias + + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class LlamaModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and 
get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
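+                # (This for/else fall-through handles weights that map
+                # one-to-one onto a single parameter; they are loaded with
+                # their own weight_loader or the default one.)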
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings" + } + embedding_padding_modules = ["lm_head"] + + # Mistral/Llama models can also be loaded with --load-format mistral + # from consolidated.safetensors checkpoints + mistral_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "qscale_act": "input_scale", + "qscale_weight": "weight_scale", + "kv_fake_quantizer.qscale_act": "kv_scale", + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm", + } + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _init_model(self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): + return LlamaModel(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + 
) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights( + self.maybe_remap_mistral(name, loaded_weight) + for name, loaded_weight in weights) + + # This function is used to remap the mistral format as + # used by Mistral and Llama <=2 + def maybe_remap_mistral( + self, + name: str, + loaded_weight: torch.Tensor, + ) -> Tuple[str, torch.Tensor]: + + def permute(w: torch.Tensor, n_heads: int): + attn_in = self.config.head_dim * n_heads + attn_out = self.config.hidden_size + + return w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, 2).reshape(attn_in, attn_out) + + mapping = self.mistral_mapping + modules = name.split(".") + + # rotary embeds should be sliced + if "wk" in modules and modules[-1] == "weight": + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + elif "wq" in modules and modules[-1] == "weight": + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + num_modules = len(modules) + for i in range(num_modules): + item = modules[i] + next_item = modules[i + 1] if i < num_modules - 1 else None + + combined_item = (f"{item}.{next_item}" + if next_item is not None else None) + + if combined_item in mapping: + name = name.replace(combined_item, mapping[combined_item]) + elif item in mapping and mapping[item] not in name: + name = name.replace(item, mapping[item]) + + return name, loaded_weight diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py new file mode 100644 index 0000000..27c8720 --- /dev/null +++ b/vllm/model_executor/models/llama4.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
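+# Llama 4 reuses the Llama decoder building blocks from llama.py and adds
+# interleaved MoE feed-forward layers plus NoPE (no-rotary) attention layers
+# with optional temperature tuning; see Llama4MoE and Llama4Attention below.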
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type + +import torch +from torch import nn +from transformers import Llama4TextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter) + + +class Llama4MoE(nn.Module): + + @staticmethod + def custom_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + router_scores, router_indices = torch.topk(gating_output, topk, dim=-1) + router_scores = torch.sigmoid(router_scores.float()).to( + hidden_states.dtype) + return (router_scores, router_indices.to(torch.int32)) + + def __init__(self, + config: Llama4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.top_k = config.num_experts_per_tok + + intermediate_size_moe = config.intermediate_size + self.router = ReplicatedLinear(config.hidden_size, + config.num_local_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.router") + + self.experts = FusedMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + custom_routing_function=Llama4MoE.custom_routing_function, + intermediate_size=intermediate_size_moe, + apply_router_weight_on_input=True, + reduce_results=False, + renormalize=False, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + reduce_results=False, # We need to do scatter before reduce + ) + + def forward(self, hidden_states): + router_logits, _ = self.router(hidden_states) + shared_out = self.shared_expert(hidden_states) + routed_out = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + experts_out = routed_out + shared_out + + if self.tp_size > 1: + experts_out = tensor_model_parallel_all_reduce(experts_out) + + return experts_out + + +class Llama4Attention(nn.Module): + + def __init__(self, + config: Llama4TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.hidden_size = 
hidden_size + self.no_rope_layers = config.no_rope_layers + self.nope = self.no_rope_layers[self.layer_idx] == 0 + self.use_qk_norm = config.use_qk_norm and not self.nope + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + # TODO: attn_temperature_tuning should be a bool in huggingface + self.attn_temperature_tuning = self.nope and \ + config.attn_temperature_tuning > 0 + + self.floor_scale = getattr(config, "floor_scale", 8192.0) + self.attn_scale = getattr(config, "attn_scale", 0.1) + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.n_rep = self.num_heads // self.num_kv_heads + self.q_norm = RMSNorm( + hidden_size=self.q_size, + eps=config.rms_norm_eps, + has_weight=False, + dtype=torch.float32, + ) if self.use_qk_norm else None + self.k_norm = RMSNorm( + hidden_size=self.kv_size, + eps=config.rms_norm_eps, + has_weight=False, + dtype=torch.float32, + ) if self.use_qk_norm else None + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "llama": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling if rope_scaling != "default" else None, + is_neox_style=is_neox_style, + ) if not self.nope else None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=None, + use_irope=not self.nope, + prefix=f"{prefix}.attn", + ) + + def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: + floor = torch.floor((positions + 1.0) / self.floor_scale) + attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0 + + return attn_scale.unsqueeze(-1) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + if self.rotary_emb is not None: + q, k = self.rotary_emb(positions, q, k) + if self.q_norm is not None: + q = self.q_norm(q.float()).to(q.dtype) + if self.k_norm is not None: + k = self.k_norm(k.float()).to(k.dtype) + + # We are 
applying temperature tuning (https://arxiv.org/abs/2501.19399) + # to NoPE layers, where the inference-time temperature tuning function + # is customized to not affect short context + # while working at very long context + # https://arxiv.org/abs/2501.19399 + # + # We should apply temperature tuning between (after) rotary / QK norm + # and (before) attention. + if self.attn_temperature_tuning and self.nope: + attn_scale = self._get_attn_scale(positions) + q = (q * attn_scale).to(q.dtype) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Llama4DecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: Llama4TextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + self.layer_idx = extract_layer_index(prefix) + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = config.rope_theta + rope_scaling = config.rope_scaling + max_position_embeddings = config.max_position_embeddings + + self.self_attn = Llama4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + bias_o_proj=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + is_moe_layer = (self.layer_idx + + 1) % config.interleave_moe_layer_step == 0 + if is_moe_layer: + self.feed_forward = Llama4MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + self.feed_forward = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size_mlp, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.feed_forward", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Llama4Model(LlamaModel): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): + self.num_experts = vllm_config.model_config.hf_config.num_local_experts + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def load_moe_expert_weights( + self, + name: str, + loaded_weight: torch.Tensor, + params_dict: Dict[str, nn.Parameter], + loaded_params: Set[str], + expert_params_mapping: List[Tuple[str, str, int, str]], + fused: bool = True, + ) -> bool: + expert_param_loaded = False + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-1) + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + new_loaded_weight = 
loaded_weight + if fused: + e_str, _, proj_str, _ = weight_name.split('.') + weight_name = f"{e_str}.{proj_str}" + param_name = f"{param_name}weight" + if weight_name not in name: + continue + full_param_name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[full_param_name] + weight_loader = param.weight_loader + if fused: + if "w13" in full_param_name: + shard_idx = 0 if shard_id == "w1" else 1 + new_loaded_weight = new_loaded_weight[shard_idx] + new_loaded_weight = new_loaded_weight.transpose(-1, -2) + layer_idx = extract_layer_index(name) + # EP mapping + expert_map = self.layers[ + layer_idx].feed_forward.experts.expert_map + if expert_map is not None: + local_expert_indices = (expert_map != -1) \ + .nonzero() \ + .flatten() \ + .to(new_loaded_weight.device) + new_loaded_weight = new_loaded_weight[local_expert_indices] + expert_id = local_expert_indices[0].item() + else: + # TODO: add EP support for non fused weights + pass + weight_loader(param, + new_loaded_weight, + full_param_name, + shard_id=shard_id, + expert_id=expert_id) + + loaded_params.add(full_param_name) + expert_param_loaded = True + return expert_param_loaded + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + fused_experts_params = False + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.num_experts) + expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_up_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="gate_up_proj", + num_experts=1) + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + fused_experts_params = True + expert_params_mapping = expert_params_mapping_fused + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name or "experts" in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + moe_loaded = self.load_moe_expert_weights( + name, + loaded_weight, + params_dict, + loaded_params, + expert_params_mapping, + fused=fused_experts_params) + + if not moe_loaded: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + 
loaded_params.add(name) + return loaded_params + + +class Llama4ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Update temperature tuning config from generation config + gen_config = vllm_config.model_config.try_get_generation_config() + gen_config.update(vllm_config.model_config.override_generation_config) + vllm_config.model_config.hf_config.attn_temperature_tuning \ + = gen_config.get("attn_temperature_tuning", False) + LlamaForCausalLM.__init__(self, + vllm_config=vllm_config, + prefix=prefix, + layer_type=Llama4DecoderLayer) + + def _init_model(self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): + return Llama4Model(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + weights = [ + self.permute_qk_weight_for_rotary(name, loaded_weight) + for name, loaded_weight in weights + ] + return loader.load_weights(weights) + + def permute_qk_weight_for_rotary( + self, + name: str, + loaded_weight: torch.Tensor, + ) -> Tuple[str, torch.Tensor]: + + def permute(w: torch.Tensor, n_heads: int): + attn_in = self.config.head_dim * n_heads + attn_out = self.config.hidden_size + + return w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, 2).reshape(attn_in, attn_out) + + modules = name.split(".") + + # rotary embeds should be sliced + if ("wk" in modules or "k_proj" in modules) \ + and modules[-1] == "weight": + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + elif ("wq" in modules or "q_proj" in modules) \ + and modules[-1] == "weight": + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + return name, loaded_weight diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py new file mode 100644 index 0000000..45a0bf7 --- /dev/null +++ b/vllm/model_executor/models/llava.py @@ -0,0 +1,931 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, + TypeVar, Union, cast) + +import torch +import torch.nn as nn +from packaging.version import Version +from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, + PixtralVisionConfig, PretrainedConfig, + SiglipVisionConfig) +from transformers import __version__ as TRANSFORMERS_VERSION +from transformers.models.llava import LlavaProcessor +from transformers.models.pixtral import PixtralProcessor + +from vllm.config import VllmConfig +from vllm.inputs import InputProcessingContext +from vllm.jsontree import json_map_leaves +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, 
MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargs) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, ProcessingCache, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .clip import CLIPVisionModel +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import (get_vision_encoder_info, scatter_patch_features, + select_patch_features) + + +class LlavaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that `height` or `width` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + +class PixtralHFImagePixelInputs(TypedDict): + type: Literal["pixel_values_pixtral"] + pixel_values: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that `height` or `width` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_images, num_embeds)` + """ + + +class LlavaImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. 
+ """ + + +LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs, + LlavaImageEmbeddingInputs] + + +class LlavaMultiModalProjector(nn.Module): + + def __init__(self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.linear_1 = ColumnParallelLinear(vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1") + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = RowParallelLinear(text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2") + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states + + +class LlavaLikeConfig(Protocol): + vision_config: Final[PretrainedConfig] + image_token_index: Final[int] + vision_feature_select_strategy: Final[str] + vision_feature_layer: Final[Union[int, list[int]]] + + +class LlavaLikeProcessor(Protocol): + image_token: Final[str] + + +class BaseLlavaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> LlavaLikeConfig: + return self.ctx.get_hf_config(LlavaConfig) + + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) + + @abstractmethod + def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor: + raise NotImplementedError + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def _apply_feature_select_strategy( + self, + strategy: str, + encoder_num_image_tokens: int, + ) -> int: + if strategy == "default": + return encoder_num_image_tokens - 1 + if strategy == "full": + return encoder_num_image_tokens + + msg = f"Unexpected feature select strategy: {strategy!r}" + raise NotImplementedError(msg) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() + + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +_I = TypeVar("_I", bound=BaseLlavaProcessingInfo) + + +class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + 
self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class LlavaProcessingInfo(BaseLlavaProcessingInfo): + + def get_hf_processor(self, **kwargs: object): + hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs) + # In case patch_size is omitted from `processor_config.json` + # e.g. for E5-V: https://huggingface.co/royokong/e5-v + if hf_processor.patch_size is None: + patch_size = self.get_vision_encoder_info().get_patch_size() + hf_processor.patch_size = patch_size + return hf_processor + + +class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): + + # Copied from BaseMultiModalProcessor + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + def get_replacement(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +class LlavaMultiModalProcessor( + BaseLlavaMultiModalProcessor[LlavaProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + +class PixtralHFProcessingInfo(BaseLlavaProcessingInfo): + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(PixtralProcessor, **kwargs) + + +class PixtralHFMultiModalProcessor( + BaseMultiModalProcessor[PixtralHFProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + pixel_values = processed_outputs.get("pixel_values") + if pixel_values is not None: + # Before/after https://github.com/huggingface/transformers/pull/35122 + if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"): + images = mm_data["images"] + assert isinstance(images, list) + + # Original output: (1, num_images, C, H, W) + # New output: (num_images, C, H, W) + assert (isinstance(pixel_values, list) + and len(pixel_values) == 1) + assert (isinstance(pixel_values[0], list) + and len(pixel_values[0]) == len(images)) + + processed_outputs["pixel_values"] = pixel_values[0] + else: + # Avoid padding since we need the output for each image to be + # independent of other images for the cache to work correctly + image_sizes = processed_outputs["image_sizes"] + assert len(pixel_values) == len(image_sizes) + + processed_outputs["pixel_values"] = 
[ + p[:, :h, :w] + for p, (h, w) in zip(pixel_values, image_sizes) + ] + + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + assert isinstance(vision_config, PixtralVisionConfig) + encoder_info = PixtralHFEncoderInfo(vision_config) + + tile_sizes = [ + encoder_info.get_patch_grid_size( + image_width=pixel_value.shape[-1], + image_height=pixel_value.shape[-2], + ) for pixel_value in processed_outputs["pixel_values"] + ] + embed_is_patch = [ + torch.tensor(([True] * ncols + [False]) * nrows) + for ncols, nrows in tile_sizes + ] + processed_outputs["embed_is_patch"] = embed_is_patch + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + hf_config = self.info.get_hf_config() + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + image_break_id = vocab[processor.image_break_token] + image_token_id = hf_config.image_token_index + image_end_id = vocab[processor.image_end_token] + + vision_config = hf_config.vision_config + assert isinstance(vision_config, PixtralVisionConfig) + encoder_info = PixtralHFEncoderInfo(vision_config) + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = encoder_info.get_patch_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + + tokens = ([image_token_id] * ncols + [image_break_id]) * nrows + tokens[-1] = image_end_id + + return tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +def _build_llava_or_pixtral_hf_info( + ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + hf_config = ctx.get_hf_config(LlavaConfig) + + if isinstance(hf_config.vision_config, PixtralVisionConfig): + return PixtralHFProcessingInfo(ctx) + + return LlavaProcessingInfo(ctx) + + +def _build_llava_or_pixtral_hf_processor( + info: _I, + dummy_inputs: BaseDummyInputsBuilder[_I], + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True, +) -> BaseMultiModalProcessor: + if isinstance(info, PixtralHFProcessingInfo): + return PixtralHFMultiModalProcessor( + info, + dummy_inputs, # type: ignore + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) + + if isinstance(info, LlavaProcessingInfo): + return LlavaMultiModalProcessor( + info, + dummy_inputs, # type: ignore + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) + + raise NotImplementedError(type(info)) + + +def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: + """Determine the number of hidden layers to initialize up to in the + visual encoder. + + Args: + hf_config: Model config with vision feature layer(s). 
+ """ + feature_layers = hf_config.vision_feature_layer + num_hidden_layers = hf_config.vision_config.num_hidden_layers + # If we have one feature layer, initialize up to that layer + if isinstance(feature_layers, int): + return _get_layer_index(feature_layers, num_hidden_layers) + # If we have multiple feature layers, initialize up to the deepest one + elif isinstance(feature_layers, (list, tuple)): + return max( + _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + +def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: + """Given a signed vision feature layer, get the number of hidden layers + needed to leverage it. + + Args: + feature_layer_index: Index of a required layer in the visual encoder. + num_hidden_layers: The total number of hidden layers in the visual + encoder. + """ + if feature_layer_index < 0: + return num_hidden_layers + feature_layer_index + 1 + return feature_layer_index + + +def init_vision_tower_for_llava( + hf_config: LlavaLikeConfig, + quant_config: Optional[QuantizationConfig], + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", +) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]: + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the deepest required feature layer + num_hidden_layers = _get_num_hidden_layers(hf_config) + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + require_post_norm=require_post_norm, + prefix=prefix, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + require_post_norm=require_post_norm, + prefix=prefix, + ) + elif isinstance(vision_config, PixtralVisionConfig): + return PixtralHFVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + require_post_norm=require_post_norm, + prefix=prefix, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor, + info=_build_llava_or_pixtral_hf_info, + dummy_inputs=LlavaDummyInputsBuilder) +class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + # NOTE: These are special cases for Pixtral-12B in the HF-format + # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa + if (config.text_config.architectures is None + and config.text_config.model_type == "mistral"): + config.text_config.architectures = ["MistralForCausalLM"] + if (config.projector_hidden_act is None + and config.vision_config.hidden_act == "gelu"): + config.projector_hidden_act = "gelu" + + # TODO: Optionally initializes this for supporting embeddings. 
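+        # Rough data flow for the image path (a sketch based on the modules
+        # below, assuming already-batched images): pixel_values (N, 3, H, W)
+        # -> vision_tower -> patch features (N, T_img, D_vis)
+        # -> multi_modal_projector -> (N, T_img, D_text), which
+        # get_input_embeddings() later merges into the text embeddings at the
+        # positions of `image_token_index`.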
+ self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + multimodal_projector_bias=config.multimodal_projector_bias, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if self.config.vision_config.model_type == "pixtral": + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + embed_is_patch = flatten_bn(embed_is_patch) + + return PixtralHFImagePixelInputs( + type="pixel_values_pixtral", + pixel_values=flatten_bn(pixel_values), + embed_is_patch=embed_is_patch, + ) + + return LlavaImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + + if self.config.vision_config.model_type == "pixtral": + raise ValueError("Pixtral-HF does not support image_embeds.") + + return LlavaImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, + PixtralHFVisionModel], + pixel_values: Union[torch.Tensor, list[torch.Tensor]], + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + + def select_features(leaf: torch.Tensor): + return self._select_image_features( + leaf, + strategy=self.config.vision_feature_select_strategy, + ) + + return cast( + Union[torch.Tensor, tuple[torch.Tensor, ...]], + json_map_leaves(select_features, image_features), + ) + + def _process_image_pixels( + self, + inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs], + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + assert self.vision_tower is not None + + pixel_values = inputs["pixel_values"] + + return self._image_pixels_to_features(self.vision_tower, pixel_values) + + def _process_image_input( + self, + image_input: LlavaImageInputs, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + + if isinstance(image_features, torch.Tensor): + return self.multi_modal_projector(image_features) + + feature_sizes = [ + image_feature.shape[0] for image_feature in image_features + ] + + image_embeds = self.multi_modal_projector(torch.cat(image_features)) + image_embeds = torch.split(image_embeds, feature_sizes) + return image_embeds + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + if image_input["type"] != "pixel_values_pixtral": + # The path is used for pixtral (V0 only) and llava (V0/V1) + return image_features + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.config.image_token_index, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, 
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        """Run forward pass for LLaVA-1.5.
+
+        One key thing to understand is that `input_ids` already accounts for
+        the positions of the to-be-inserted image embeddings.
+
+        Concretely, consider a text prompt:
+        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
+
+        Tokenizer outputs:
+        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
+        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
+        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
+        29901]`.
+
+        We insert 575 tokens so that including the original image token in the
+        input, there are a total of 576 (24 * 24) image tokens, which
+        corresponds to the number of image tokens inputted to the language
+        model, i.e. the number of image tokens outputted by the visual encoder.
+
+        This way, the `positions` and `attn_metadata` are consistent
+        with the `input_ids`.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            pixel_values: The pixels in each input image.
+
+        See also:
+            :class:`LlavaImageInputs`
+        """
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
+
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  intermediate_tensors,
+                                                  inputs_embeds=inputs_embeds)
+
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        return self.language_model.sample(logits, sampling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+
+class MantisProcessingInfo(LlavaProcessingInfo):
+
+    def get_hf_processor(self, **kwargs: object):
+        hf_config = self.get_hf_config()
+        vision_info = self.get_vision_encoder_info()
+
+        kwargs.setdefault("patch_size", vision_info.get_patch_size())
+
+        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
+            # BUG: num_additional_image_tokens = 0 but treated as 1,
+            # so we set vision_feature_select_strategy to None to offset this
+            kwargs.setdefault("vision_feature_select_strategy", None)
+        else:
+            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
+            kwargs.setdefault(
+                "vision_feature_select_strategy",
+                hf_config.vision_feature_select_strategy,
+            )
+
+        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
+
+
+class MantisMultiModalProcessor(LlavaMultiModalProcessor):
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
+    ) -> MultiModalInputs:
+        hf_config = self.info.get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        # Assume that it doesn't depend on the image size
+        num_image_tokens = self.info.get_num_image_tokens(
+            image_width=-1,
+            image_height=-1,
+        )
+
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
+                               return_mm_hashes)
+
+        mm_items = self._to_mm_items(mm_data)
+        mm_item_counts = mm_items.get_all_counts()
+        mm_kwargs = result["mm_kwargs"]
+        mm_hashes = result["mm_hashes"]
+
+        # We reimplement the functionality of MLlavaProcessor from
+        # https://github.com/TIGER-AI-Lab/Mantis.git
+        def get_replacement_mantis(item_idx: int):
+            return "".join([
+                f"(image {item_idx+1}: ",  # 7 tokens
+                "<image>" * num_image_tokens,
+                ")",  # 3 tokens
+            ])
+
+        mantis_mm_repls = self._bind_and_group_updates([
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id] * num_image_tokens,
+                replacement=get_replacement_mantis,
+            )
+        ])
+
+        prompt_ids, prompt, _ = self._apply_prompt_updates(
+            result["prompt_token_ids"],
+            mantis_mm_repls,
+            mm_item_counts,
+        )
+
+        unbound_orig_repls = self._get_prompt_updates(
+            mm_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        orig_repls = self._bind_and_group_updates(unbound_orig_repls)
+
+        mm_placeholders = self._find_mm_placeholders(
+            orig_repls,
+            prompt_ids,
+            mm_item_counts,
+        )
+        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
+
+        mm_placeholder_ranges = {
+            modality: [item.to_range() for item in placeholders]
+            for modality, placeholders in mm_placeholders.items()
+        }
+
+        return MultiModalInputs(
+            type="multimodal",
+            prompt=prompt,
+            prompt_token_ids=prompt_ids,
+            mm_kwargs=mm_kwargs,
+            mm_hashes=mm_hashes,
+            mm_placeholders=mm_placeholder_ranges,
+        )
+
+
+# To use this model, please use
+# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
+@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
+                                        info=MantisProcessingInfo,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
+class MantisForConditionalGeneration(LlavaForConditionalGeneration):
+    pass
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
new file mode 100644
index 0000000..4de13e5
--- /dev/null
+++ b/vllm/model_executor/models/llava_next.py
@@ -0,0 +1,595 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import abstractmethod
+from functools import cached_property
+from typing import (Final, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
+
+import torch
+import torch.nn as nn
+from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
+from transformers.models.llava_next.modeling_llava_next import (
+    get_anyres_image_grid_shape, unpad_image)
+from typing_extensions import NotRequired
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalFieldConfig
+from vllm.multimodal.parse import ImageSize
+from vllm.sequence import IntermediateTensors
+
+from .clip import CLIPVisionModel
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
+                    LlavaDummyInputsBuilder, LlavaLikeConfig,
+                    LlavaMultiModalProjector, init_vision_tower_for_llava)
+from .siglip import SiglipVisionModel
+from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
+                    
init_vllm_registered_model, maybe_prefix) + + +class LlavaNextImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + image_sizes: NotRequired[torch.Tensor] + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class LlavaNextImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, + LlavaNextImageEmbeddingInputs] + + +class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): + image_grid_pinpoints: Final[list[list[int]]] + + +class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): + + def get_hf_config(self) -> LlavaNextLikeConfig: + return self.ctx.get_hf_config(LlavaNextConfig) + + def get_hf_processor(self, **kwargs: object): + hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs) + + # In case patch_size is omitted from `processor_config.json` + # e.g. for E5-V: https://huggingface.co/royokong/e5-v + if hf_processor.patch_size is None: + patch_size = self.get_vision_encoder_info().get_patch_size() + hf_processor.patch_size = patch_size + + return hf_processor + + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113 + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() + + base_feature_size = self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(image_height, image_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=vision_encoder_info.get_image_size(), + ) + + ( + unpadded_feature_size, + newline_feature_size, + ) = self._get_num_unpadded_features( + original_height=image_height, + original_width=image_width, + npatches=vision_encoder_info.get_patch_grid_length(), + num_patch_height=num_patch_height, + num_patch_width=num_patch_width, + ) + + return unpadded_feature_size + newline_feature_size + base_feature_size + + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 + def _get_num_unpadded_features( + self, + *, + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, + ) -> tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if aspect_ratio > current_aspect_ratio: + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) + else: + new_width = (original_width * current_height) // original_height + padding = (current_width - 
new_width) // 2 + current_width = current_width - (2 * padding) + + unpadded_features = current_height * current_width + newline_features = current_height + + return (unpadded_features, newline_features) + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + + largest_feature_size, largest_feature_pinpoint = 0, None + for (height, width) in hf_config.image_grid_pinpoints: + feat_size = self.get_num_image_tokens(image_width=width, + image_height=height) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + +_I = TypeVar("_I", bound=LlavaNextProcessingInfo) + + +class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]): + + # Copied from BaseMultiModalProcessor + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError + + +class LlavaNextMultiModalProcessor( + BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + +@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, + info=LlavaNextProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder) +class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + vision_feature_layer = config.vision_feature_layer + # Determine the layer up to which we will initialize the vision tower + if isinstance(vision_feature_layer, int): + vision_hidden_size = config.vision_config.hidden_size + self.feature_sample_layers = None + # Used for multimodal granite models to control encoder outputs + elif isinstance(vision_feature_layer, (list, tuple)): + vision_hidden_size = config.vision_config.hidden_size * len( + vision_feature_layer) + self.feature_sample_layers = vision_feature_layer + else: + raise TypeError( + f"vision_layer_feature type: {type(vision_feature_layer)}" + " is not supported") + + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. 
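+        # Compared to LLaVA-1.5, LLaVA-NeXT encodes each image as a base
+        # thumbnail plus an "anyres" grid of higher-resolution patches, and a
+        # learned `image_newline` embedding marks row boundaries when the
+        # unpadded patch grid is flattened
+        # (see _merge_image_patch_embeddings below).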
+ self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=vision_hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + multimodal_projector_bias=config.multimodal_projector_bias) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaNextImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + + return LlavaNextImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values( + flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeds. 
" + f"Got type: {type(image_embeds)}") + + return LlavaNextImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower( + pixel_values, feature_sample_layers=self.feature_sample_layers) + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py + def _merge_image_patch_embeddings(self, image_size: torch.Tensor, + patch_embeddings: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # Move to CPU to avoid floating-point errors + orig_height, orig_width = image_size.tolist() + + # image_aspect_ratio == "anyres" + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + (orig_height, orig_width), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ + .view(num_patch_height, num_patch_width, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + (orig_height, orig_width)) + other_patch_embeds = torch.cat(( + other_patch_embeds, + self.image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + .to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _process_image_pixels( + self, + inputs: LlavaNextImagePixelInputs, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + assert 
self.vision_tower is not None
+
+        pixel_values = inputs["pixel_values"]
+
+        if isinstance(pixel_values, torch.Tensor):
+            b, num_patches, c, h, w = pixel_values.shape
+            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+            stacked_image_features = self._image_pixels_to_features(
+                self.vision_tower, stacked_pixel_values)
+            stacked_patch_embeddings = self.multi_modal_projector(
+                stacked_image_features)
+
+            return stacked_patch_embeddings.view(
+                b, num_patches, *stacked_patch_embeddings.shape[1:])
+
+        num_patches_per_batch = [v.shape[0] for v in pixel_values]
+        stacked_pixel_values = torch.cat(pixel_values)
+        stacked_image_features = self._image_pixels_to_features(
+            self.vision_tower, stacked_pixel_values)
+
+        return torch.split(self.multi_modal_projector(stacked_image_features),
+                           num_patches_per_batch)
+
+    def _process_image_input(
+        self,
+        image_input: LlavaNextImageInputs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if image_input["type"] == "image_embeds":
+            return [image_input["data"]]
+
+        patch_embeddings = self._process_image_pixels(image_input)
+
+        image_sizes = image_input.get("image_sizes")
+        if image_sizes is None:
+            batch_size = len(image_input["data"])
+            vision_config = self.config.vision_config
+            default_height = default_width = vision_config.image_size
+            image_sizes = torch.as_tensor([[default_height, default_width]
+                                           for _ in range(batch_size)])
+
+        return [
+            self._merge_image_patch_embeddings(image_sizes[i],
+                                               patch_features_batch,
+                                               strategy="spatial_unpad")
+            for i, patch_features_batch in enumerate(patch_embeddings)
+        ]
+
+    def get_multimodal_embeddings(
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+
+        if multimodal_embeddings is None:
+            return self.language_model.get_input_embeddings(input_ids)
+
+        inputs_embeds = embed_multimodal(
+            input_ids,
+            self.config.image_token_index,
+            self.language_model.model.get_input_embeddings,
+            multimodal_embeddings,
+        )
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs: object,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        """Run forward pass for LLaVA-NeXT.
+
+        One key thing to understand is that `input_ids` already accounts for
+        the positions of the to-be-inserted image embeddings.
+
+        Concretely, consider a text prompt:
+        `"A chat between a curious human and an artificial intelligence
+        assistant. The assistant gives helpful, detailed, and polite answers to
+        the human's questions.
+        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
+
+        Tokenizer outputs:
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
+        9047, 13566, 29901]`.
+ + To reserve space in KV cache, we have to insert placeholder tokens + before they are inputted to the model, so the input processor prepends + additional image tokens (denoted as `32000`), resulting in: + `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255, + 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, + 6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901, + 29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, + 319, 1799, 9047, 13566, 29901]`. + + Unlike in LLaVA-1.5, the number of image tokens inputted to the language + model depends on the original size of the input image. Including the + original image token in the input, the required number of image tokens + is given by :func:`get_llava_next_image_feature_size`. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: The pixels in each grid patch for each input image. + image_sizes: The original `(height, width)` for each input image. + + See also: + :class:`LlavaNextImageInputs` + """ + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py new file mode 100644 index 0000000..780af72 --- /dev/null +++ b/vllm/model_executor/models/llava_next_video.py @@ -0,0 +1,498 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import (BatchFeature, LlavaNextVideoConfig, + LlavaNextVideoProcessor) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.models.clip import CLIPVisionModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + VideoEmbeddingItems, VideoProcessorItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, 
ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .llava import init_vision_tower_for_llava +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import get_vision_encoder_info + + +class LlavaNextVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: `(batch_size, num_frames, num_channels, height, width)` + + Note that `num_frames` may be different for each batch, in which case + the data is passed as a list instead of a batched tensor. + + Note that it only supports one video input for one batch. + """ + + +class LlavaNextVideoProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(LlavaNextVideoConfig) + + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(LlavaNextVideoProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"video": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + target_width, target_height = self.get_image_size_with_most_features() + + max_video_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + ) + + return {"video": max_video_tokens} + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def _get_num_frame_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + spatial_pool_stride = hf_config.spatial_pool_stride + + vision_encoder_info = self.get_vision_encoder_info() + patch_grid_length = vision_encoder_info.get_patch_grid_length() + pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) + + return pooled_grid_length * pooled_grid_length + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + ) -> int: + num_frame_tokens = self._get_num_frame_tokens( + image_width=image_width, + image_height=image_height, + ) + + return num_frame_tokens * num_frames + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_videos = mm_counts.get("video", 0) + + max_total_frames = self._get_max_video_frames(seq_len) + + return max(max_total_frames // max(max_videos, 1), 1) + + +class LlavaNextVideoDummyInputsBuilder( + BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + 
num_videos = mm_counts.get("video", 0) + + processor = self.info.get_hf_processor() + video_token = processor.video_token + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + mm_data = { + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + + return ProcessorInputs( + prompt_text=video_token * num_videos, + mm_data=mm_data, + ) + + +class LlavaNextVideoMultiModalProcessor( + BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + video_token_id = hf_config.video_token_index + + def get_replacement(item_idx: int): + videos = mm_items.get_items( + "video", (VideoEmbeddingItems, VideoProcessorItems)) + + if isinstance(videos, VideoEmbeddingItems): + num_video_tokens = videos.get_feature_size(item_idx) + else: + image_size = videos.get_frame_size(item_idx) + num_video_tokens = self.info.get_num_video_tokens( + image_width=image_size.width, + image_height=image_size.height, + num_frames=videos.get_num_frames(item_idx), + ) + + return [video_token_id] * num_video_tokens + + return [ + PromptReplacement( + modality="video", + target=[video_token_id], + replacement=get_replacement, + ), + ] + + +# adopted from transformers modeling_llava_next_video.py +class LlavaNextVideoPooler(nn.Module): + + def __init__(self, config: LlavaNextVideoConfig): + super().__init__() + + mode = config.spatial_pool_mode + stride = config.spatial_pool_stride + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.image_size = image_size // patch_size**2 + + if mode == "average": + self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride) + elif mode == "max": + self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + else: + # TODO: Support Conv2d pooling layer, need to load weights + raise ValueError( + f"Unknown pooling mode: {mode}. 
Expected [`average`, `max`]") + + def forward(self, image_features: torch.Tensor): + ori_width = int( + math.sqrt(image_features.shape[1] * self.image_size // + self.image_size)) + ori_height = int(ori_width * self.image_size // self.image_size) + + batch_size, _, dim = image_features.shape + image_features_spatial = image_features \ + .view(batch_size, ori_height, ori_height, dim) \ + .permute(0, 3, 1, 2) + image_features_spatial = self.pool(image_features_spatial) + + return image_features_spatial.flatten(2).transpose(1, 2).contiguous() + + +class LlavaNextMultiModalProjector(nn.Module): + + def __init__(self, vision_hidden_size: int, text_hidden_size: int, + projector_hidden_act: str, multimodal_projector_bias: bool): + super().__init__() + + self.linear_1 = nn.Linear(vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias) + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = nn.Linear(text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextVideoMultiModalProcessor, + info=LlavaNextVideoProcessingInfo, + dummy_inputs=LlavaNextVideoDummyInputsBuilder, +) +class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + # Initialize the vision tower only up to the required feature layer + self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.vision_resampler = LlavaNextVideoPooler(config) + self.multi_modal_projector = LlavaNextMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + multimodal_projector_bias=config.multimodal_projector_bias) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_video_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[2:]) + + if actual_dims != expected_dims: + expected_expr = ("num_frames", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values in each video frame " + f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]: + """ + A legal video input should have the following dimensions: + { + "pixel_values_videos" : + List[b, Tensor(nb_frames, nb_channels, height, width)] + } + """ + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + + if pixel_values_videos is None: + return None + + if not isinstance(pixel_values_videos, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel_values_videos. " + f"Got type: {type(pixel_values_videos)}") + + return LlavaNextVideoPixelInputs( + type="pixel_values_videos", + data=pixel_values_videos, + ) + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _video_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + image_features = self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + image_features = self.vision_resampler(image_features) + image_features = self.multi_modal_projector(image_features) + return image_features + + def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): + assert self.vision_tower is not None + + video_pixels = inputs["data"] + + if isinstance(video_pixels, torch.Tensor): + # TODO: support multiple videos per input + b, num_videos, num_frames, c, h, w = video_pixels.shape + assert (num_videos == 1) + stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, + h, w) + stacked_embeddings = self._video_pixels_to_features( + self.vision_tower, stacked_pixels) + embeds = stacked_embeddings.view(b, num_frames, + *stacked_embeddings.shape[1:]) + + elif is_list_of(video_pixels, torch.Tensor): + frames_per_videos = [v.shape[0] for v in video_pixels] + stacked_pixels = torch.cat(video_pixels, dim=0) + stacked_embeddings = self._video_pixels_to_features( + self.vision_tower, stacked_pixels) + embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0) + else: + raise ValueError( + f"Unsupported type of video input {type(video_pixels)}") + + return [e.flatten(0, 1) for e in embeds] + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + video_input = self._parse_and_validate_video_input(**kwargs) + if video_input is None: + return None + vision_embeddings = self._process_video_pixels(video_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.video_token_index) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) 
-> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for LlaVA-NeXT-Video. + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values_videos: Pixels in each frames for each input videos. + """ + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + # This model doesn't support images for now + ignore_unexpected_prefixes=["image_newline"], + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py new file mode 100644 index 0000000..c7e13bb --- /dev/null +++ b/vllm/model_executor/models/llava_onevision.py @@ -0,0 +1,976 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +from transformers import (BatchFeature, LlavaOnevisionConfig, + LlavaOnevisionProcessor) +from transformers.models.llava_onevision.modeling_llava_onevision import ( + get_anyres_image_grid_shape, unpad_image) +from typing_extensions import NotRequired + +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + VideoEmbeddingItems, VideoProcessorItems) +from vllm.multimodal.processing import PromptReplacement, PromptUpdate +from vllm.multimodal.profiling import ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .clip import CLIPVisionModel +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava +from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig, + LlavaNextProcessingInfo) +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 + + +class LlavaOnevisionVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_videos, 
num_frames, num_channels, height, width)` + + Note that `num_videos` may be different for each batch, and 'num_frames' + may be different for each video, in which case the data is passed as a + list instead of a batched tensor. + """ + + +class LlavaOnevisionImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + image_sizes: NotRequired[torch.Tensor] + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class LlavaOnevisionImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, + LlavaOnevisionImageEmbeddingInputs] + +LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs, + LlavaOnevisionVideoPixelInputs] + + +class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): + video_token_index: Final[int] + + +class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): + + def get_hf_config(self) -> LlavaOnevisionLikeConfig: + return self.ctx.get_hf_config(LlavaOnevisionConfig) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "video": self.get_max_video_tokens(seq_len, mm_counts), + } + + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 + # with additional logic afterwards taken from LlavaOnevisionProcessor + def _get_num_unpadded_features( + self, + *, + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, + ) -> tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if aspect_ratio > current_aspect_ratio: + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) + else: + new_width = (original_width * current_height) // original_height + padding = (current_width - new_width) // 2 + current_width = current_width - (2 * padding) + + unpadded_features = current_height * current_width + newline_features = current_height + + ratio = math.sqrt(current_height * current_width / (9 * npatches**2)) + if ratio > 1.1: + height_factor = int(current_height // ratio) + width_factor = int(current_width // ratio) + unpadded_features = height_factor * width_factor + newline_features = height_factor + + return (unpadded_features, newline_features) + + def get_image_size_with_most_features(self) -> ImageSize: + # NOTE: This hardcoded value is found via processor tests + return ImageSize(width=1153, height=944) + + def _get_num_frame_tokens( + self, + 
*, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2) + + vision_encoder_info = self.get_vision_encoder_info() + patch_grid_length = vision_encoder_info.get_patch_grid_length() + pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) + + return pooled_grid_length * pooled_grid_length + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + ) -> int: + num_frame_tokens = self._get_num_frame_tokens( + image_width=image_width, + image_height=image_height, + ) + + return num_frame_tokens * num_frames + 1 # Newline token + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 1) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + ) + + +class LlavaOnevisionDummyInputsBuilder( + LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + video_token = processor.video_token + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, + mm_counts) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + + return ProcessorInputs( + prompt_text=image_token * num_images + video_token * num_videos, + mm_data=mm_data, + ) + + +class LlavaOnevisionMultiModalProcessor( + BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.batched("video"), + ) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: 
Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + videos = mm_data.pop("videos", []) + assert isinstance(videos, list) + + if not videos: + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + # LLaVA-OneVision processor doesn't support multiple videos + # with different sizes when converting back to tensors + # So, we process each component separately + # NOTE: No prompt replacement is applied in this case + processor = self.info.get_hf_processor() + image_token = processor.image_token + video_token = processor.video_token + + text_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data={}, + mm_kwargs=mm_kwargs, + ) + + images = mm_data.pop("images", []) + assert isinstance(images, list) + if images: + processor_outputs = super()._call_hf_processor( + prompt=image_token * len(images), + mm_data={"images": images}, + mm_kwargs=mm_kwargs, + ) + image_outputs = { + k: v + for k, v in processor_outputs.items() + if k in ("pixel_values", "image_sizes") + } + else: + image_outputs = {} + + pixel_values_videos = [] + for video in videos: + item_outputs = super()._call_hf_processor( + prompt=video_token, + mm_data={"videos": video}, + mm_kwargs=mm_kwargs, + ) + + pixel_values_videos.append(item_outputs["pixel_values_videos"][0]) + + video_outputs = {"pixel_values_videos": pixel_values_videos} + + combined_outputs = dict( + text_outputs, + **image_outputs, + **video_outputs, + ) + return BatchFeature(combined_outputs) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + base_result = super()._hf_processor_applies_updates( + prompt_text=prompt_text, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return base_result and mm_items.get_count("video", strict=False) == 0 + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_repls = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + hf_config = self.info.get_hf_config() + video_token_id = hf_config.video_token_index + + def get_video_replacement(item_idx: int): + videos = mm_items.get_items( + "video", (VideoEmbeddingItems, VideoProcessorItems)) + + if isinstance(videos, VideoEmbeddingItems): + num_video_tokens = videos.get_feature_size(item_idx) + else: + image_size = videos.get_frame_size(item_idx) + num_video_tokens = self.info.get_num_video_tokens( + image_width=image_size.width, + image_height=image_size.height, + num_frames=videos.get_num_frames(item_idx), + ) + + return [video_token_id] * num_video_tokens + + return [ + *image_repls, + PromptReplacement( + modality="video", + target=[video_token_id], + replacement=get_video_replacement, + ), + ] + + +class LlavaOnevisionMultiModalProjector(nn.Module): + + def __init__(self, config: LlavaOnevisionConfig): + super().__init__() + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias) + self.act = get_act_fn(config.projector_hidden_act) + self.linear_2 = nn.Linear(config.text_config.hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(image_features) + 
hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaOnevisionMultiModalProcessor, + info=LlavaOnevisionProcessingInfo, + dummy_inputs=LlavaOnevisionDummyInputsBuilder) +class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + # Initialize the vision tower only up to the required feature layer + self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_image_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + + return LlavaOnevisionImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_image_pixel_values( + flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeds. 
" + f"Got type: {type(image_embeds)}") + + return LlavaOnevisionImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _validate_video_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[2:]) + + if actual_dims != expected_dims: + expected_expr = ("num_frames", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values in each video frame " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_video_input( + self, + **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: + """ + A legal video input should have the following dimensions: + { + "pixel_values_videos" : + List[b, Tensor(nb_frames, nb_channels, height, width)] + } + """ + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + if pixel_values_videos is None: + return None + + if not isinstance(pixel_values_videos, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel_values_videos. " + f"Got type: {type(pixel_values_videos)}") + + return LlavaOnevisionVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=flatten_bn(pixel_values_videos), + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py + def _merge_image_patch_embeddings(self, + image_size: torch.Tensor, + patch_embeddings: torch.Tensor, + *, + image_newline=None, + vision_aspect_ratio="anyres_max_9", + strategy: str) -> torch.Tensor: + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # Move to CPU 
to avoid floating-point errors + orig_height, orig_width = image_size.tolist() + + # image_aspect_ratio == "anyres" + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + (orig_height, orig_width), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ + .view(num_patch_height, num_patch_width, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + (orig_height, orig_width)) + max_num_patches = int( + vision_aspect_ratio.removeprefix("anyres_max_")) + channels, curr_height, curr_width = other_patch_embeds.shape + ratio = math.sqrt(curr_height * curr_width / + (max_num_patches * height**2)) + if ratio > 1.1: + other_patch_embeds = other_patch_embeds[None] + other_patch_embeds = nn.functional.interpolate( + other_patch_embeds, [ + int(curr_height // ratio), + int(curr_width // ratio) + ], + mode="bilinear")[0] + if image_newline is not None: + other_patch_embeds = torch.cat( + ( + other_patch_embeds, + image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), + dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + .to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _process_image_pixels( + self, + inputs: LlavaOnevisionImagePixelInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + assert self.vision_tower is not None + + pixel_values = inputs["pixel_values"] + + if isinstance(pixel_values, torch.Tensor): + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + stacked_patch_embeddings = self.multi_modal_projector( + stacked_image_features) + + return stacked_patch_embeddings.view( + b, num_patches, *stacked_patch_embeddings.shape[1:]) + + num_patches_per_batch = [v.shape[0] for v in pixel_values] + stacked_pixel_values = torch.cat(pixel_values) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + + return [ + self.multi_modal_projector(image_features) for image_features in + torch.split(stacked_image_features, num_patches_per_batch) + ] + + def _process_image_input( + self, + image_input: LlavaOnevisionImageInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if image_input["type"] == "image_embeds": + return [image_input["data"]] + + patch_embeddings = self._process_image_pixels(image_input) + + image_sizes = image_input.get("image_sizes") + if image_sizes is None: + batch_size = len(image_input["pixel_values"]) + vision_config = self.config.vision_config + default_height = default_width = vision_config.image_size + image_sizes = 
torch.as_tensor([[default_height, default_width] + for _ in range(batch_size)]) + + return [ + self._merge_image_patch_embeddings( + image_sizes[i], + patch_features_batch, + image_newline=self.image_newline, + strategy="spatial_unpad") + for i, patch_features_batch in enumerate(patch_embeddings) + ] + + def _video_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + video_features = vision_tower(pixel_values) + video_features = self._select_image_features( + video_features, + strategy=self.config.vision_feature_select_strategy, + ) + video_features = self.multi_modal_projector(video_features) + video_features = self.apply_pooling(video_features) + return video_features + + def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs): + assert self.vision_tower is not None + + video_pixels = inputs["pixel_values_videos"] + + if isinstance(video_pixels, torch.Tensor): + total_videos, frames, c, h, w = video_pixels.shape + video_pixels_flat = video_pixels.view(total_videos * frames, c, h, + w) + + embeddings_flat = self._video_pixels_to_features( + self.vision_tower, video_pixels_flat) + + embeddings_flat = embeddings_flat.reshape( + total_videos, frames * embeddings_flat.shape[1], -1) + + image_newline = self.image_newline[None, None, :].expand( + total_videos, -1, -1) + return torch.cat((embeddings_flat, image_newline), dim=1) + + frames_per_video = [len(video) for video in video_pixels] + video_pixels_flat = torch.cat(video_pixels) + + embeddings_flat = self._video_pixels_to_features( + self.vision_tower, video_pixels_flat) + + image_newline = self.image_newline[None, None, :] + + return [ + torch.cat( + ( + embeds.reshape(1, num_frame * embeddings_flat.shape[1], + -1), + image_newline, + ), + dim=1, + ) for num_frame, embeds in zip( + frames_per_video, + torch.split(embeddings_flat, frames_per_video), + ) + ] + + def apply_pooling(self, image_features: torch.Tensor, stride: int = 2): + vision_config = self.config.vision_config + height = width = vision_config.image_size // vision_config.patch_size + batch_frames, _, dim = image_features.shape + image_features = image_features.view(batch_frames, height, width, -1) + image_features = image_features.permute(0, 3, 1, 2) + + # TODO support other pooling types config + height, width = image_features.shape[2:] + scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] + image_feature = nn.functional.interpolate(image_features, + size=scaled_shape, + mode='bilinear') + image_feature = image_feature.permute(0, 2, 3, 1) + image_feature = image_feature.view(batch_frames, -1, dim) + return image_feature + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
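+        # For example, a prompt containing an image followed by a video must
+        # yield its image embeddings before its video embeddings, so that the
+        # returned tuple lines up with the multimodal placeholder tokens in
+        # the prompt.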
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(vision_embeddings) + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_pixels(video_input) + multimodal_embeddings += tuple(video_embeddings) + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_index, self.config.video_token_index]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[LlavaOnevisionImagePixelInputs] = None, + video_input: Optional[LlavaOnevisionVideoPixelInputs] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_index, + ) + + if video_input is not None: + video_embeds = self._process_video_pixels(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_index, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for LlaVA-Onevision. + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values_videos: Pixels in each frames for each input videos. + """ + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. 
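+        # In that legacy path, image/video inputs are parsed from the kwargs
+        # here, merged into `inputs_embeds`, and `input_ids` is cleared so the
+        # language model consumes the embeddings directly.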
+ elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py new file mode 100644 index 0000000..7a525ad --- /dev/null +++ b/vllm/model_executor/models/mamba.py @@ -0,0 +1,276 @@ +# SPDX-License-Identifier: Apache-2.0 +"""PyTorch MAMBA model.""" +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import MambaConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (HasInnerState, + IsAttentionFree, SupportsPP, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False) -> None: + super().__init__() + self.config = config + self.is_falcon_mamba = config.model_type == "falcon_mamba" + self.is_lora_enabled = is_lora_enabled + mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None + self.mixer = MambaMixer(hidden_size=config.hidden_size, + ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=config.intermediate_size, + time_step_rank=config.time_step_rank, + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + use_rms_norm=self.is_falcon_mamba, + 
rms_norm_has_weight=not self.is_falcon_mamba, + rms_norm_eps=mixer_rms_eps, + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled) + + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, mamba_cache_params) + return hidden_states, residual + + +class MambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MambaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config, + is_lora_enabled=is_lora_enabled), + prefix=f"{prefix}.layers") + + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=mamba_cache_params.at_layer_idx( + i - self.start_layer)) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states + + +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + self.scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.model_config = 
vllm_config.model_config + self.backbone = MambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "backbone")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if config.tie_word_embeddings: + self.lm_head = self.backbone.embeddings + else: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.backbone.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel - 1, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "A_log" in name: + name = name.replace("A_log", "A") + # Skip loading extra bias for GPTQ models. 
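+            # (quantized checkpoints may contain bias tensors with no matching
+            # parameter in this model; such entries are skipped.)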
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py new file mode 100644 index 0000000..da5cbdd --- /dev/null +++ b/vllm/model_executor/models/mamba2.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +"""PyTorch MAMBA2 model.""" +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (HasInnerState, + IsAttentionFree, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Mamba2DecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.config = config + self.mixer = MambaMixer2(hidden_size=config.hidden_size, + ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=getattr( + config, "intermediate_size", + config.expand * config.hidden_size), + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + n_groups=config.n_groups, + num_heads=config.num_heads, + head_dim=config.head_dim, + rms_norm_eps=config.layer_norm_epsilon, + activation=config.hidden_act, + chunk_size=config.chunk_size, + quant_config=quant_config) + + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, mamba_cache_params, + sequence_idx) + return hidden_states, residual + + +class Mamba2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = 
vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) + assert not is_lora_enabled + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Mamba2DecoderLayer(config, + quant_config=quant_config), + prefix=f"{prefix}.layers") + + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + for i in range(len(self.layers)): + layer = self.layers[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=mamba_cache_params.at_layer_idx( + i - self.start_layer), + sequence_idx=seq_idx) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states + + +class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + self.backbone = Mamba2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "backbone")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + 
org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.backbone.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = getattr( + self.config, "intermediate_size", + self.config.expand * self.config.hidden_size) + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = ( + self.config.n_groups + + extra_groups_for_head_shards(self.config.n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.conv_kernel - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.num_heads, world_size), + self.config.head_dim, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "A_log" in name: + name = name.replace("A_log", "A") + + # 
Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py new file mode 100644 index 0000000..2583972 --- /dev/null +++ b/vllm/model_executor/models/mamba_cache.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Tuple + +import torch + +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.model_executor.models.constant_size_cache import ConstantSizeCache + + +@dataclass +class MambaCacheParams: + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self, layer_idx): + return MambaCacheParams(self.conv_state[layer_idx], + self.ssm_state[layer_idx], + self.state_indices_tensor) + + +class MambaCacheManager(ConstantSizeCache): + + def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, + num_mamba_layers: int, conv_state_shape: Tuple[int, int], + temporal_state_shape: Tuple[int, int]): + + # Determine max batch size to set size of MambaCache + max_batch_size = vllm_config.scheduler_config.max_num_seqs + if not vllm_config.model_config.enforce_eager: + max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) + + # Initialize parent class + super().__init__(max_batch_size) + + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda") + temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda") + + self._mamba_cache = (conv_state, temporal_state) + + @property + def cache(self): + return self._mamba_cache + + def _copy_cache(self, from_index: int, to_index: int): + for cache_t in self.cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) + + def current_run_tensors(self, **kwargs) -> MambaCacheParams: + """ + Return the tensors for the current run's conv and ssm state. + """ + cache_tensors, state_indices_tensor = super().current_run_tensors( + **kwargs) + return MambaCacheParams(cache_tensors[0], cache_tensors[1], + state_indices_tensor) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
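+        Concretely, this returns the shared (conv_state, ssm_state) buffers
+        together with a state-indices tensor of length `batch_size` that is
+        pre-filled with PAD_SLOT_ID.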
+ """ + return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py new file mode 100644 index 0000000..a19d7da --- /dev/null +++ b/vllm/model_executor/models/medusa.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata + + +class ResidualBlock(nn.Module): + + def __init__(self, config: VllmConfig, hidden_size: int, + num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList([ + nn.Linear(hidden_size, + hidden_size, + bias=getattr(config, "medusa_fc_bias", False)) + for _ in range(num_layers) + ]) + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = x + self.act(layer(x)) + return x + + +class Medusa(nn.Module): + """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774 + Reference implementation: https://github.com/FasterDecoding/Medusa + + Differences from reference implementation: + 1. Currently this only supports generating proposals from top-1 tokens. + 2. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + ResidualBlock(config=config, + hidden_size=self.config.hidden_size, + num_layers=self.config.num_hidden_layers) + for _ in range(self.config.num_heads) + ]) + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + if getattr(config, "original_lm_head", False): + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + self.lm_heads = [ + self.lm_head for _ in range(self.config.num_heads) + ] + else: + self.lm_heads = nn.ModuleList([ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) for _ in range(self.config.num_heads) + ]) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. 
Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. + self.token_map = None + + def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: + return [block(hidden_states) for block in self.blocks] + + def compute_logits( + self, hidden_states: List[torch.Tensor], + sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: + logits_lst: List[torch.Tensor] = [] + + for hs, lm_head in zip(hidden_states, self.lm_heads): + _logits = self.logits_processor(lm_head, hs, sampling_metadata) + + if _logits is None: + # _logits should only be None on rank > 0, in which case + # it should remain true for every lm_head + assert len(logits_lst) == 0 + continue + + if self.token_map is None: + logits_lst.append(_logits) + else: + logits_lst.append(-torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype)) + + logits_lst[-1][..., self.token_map] = _logits + + return logits_lst + + def sample( + self, + logits: List[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + logits = torch.stack(logits, dim=0).float() + logprobs = torch.log_softmax(logits, dim=-1) + token_ids = logits.argmax(-1) # support only top-1 for now + probs = torch.softmax(logits, dim=-1) + + token_id_list = [] + token_prob_list = [] + token_logprob_list = [] + + for idx, seq_group in enumerate(sampling_metadata.seq_groups): + token_id_list.append(token_ids[:, seq_group.sample_indices]) + token_prob_list.append(probs[:, seq_group.sample_indices]) + token_logprob_list.append(logprobs[:, seq_group.sample_indices]) + + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(sampling_metadata.seq_groups)): + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx].squeeze(1), + logprobs=token_logprob_list[idx].squeeze(1), + sampled_token_ids=token_id_list[idx].squeeze(1), + )) + + return outputs + + def generate_proposals( + self, + previous_hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + return self.sample( + logits=self.compute_logits( + hidden_states=self.forward(previous_hidden_states), + sampling_metadata=sampling_metadata, + ), + sampling_metadata=sampling_metadata, + ) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + weights_map = {} + + for name, loaded_weight in weights: + name = name.replace("medusa_heads.", "") + + if name == "token_map": + if self.truncated_vocab_size < self.orig_vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name in params_dict: + weights_map[name] = loaded_weight + elif (getattr(self.config, "original_lm_head", False) + and name == "lm_heads.0.weight"): + weights_map["lm_head.weight"] = loaded_weight + + for name, loaded_weight in weights_map.items(): + if "lm_head" in name and self.token_map is not None and\ + loaded_weight.shape[0] > self.token_map.shape[0]: + + loaded_weight = loaded_weight[self.token_map] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.token_map is not None: + 
self.token_map.to(device=self.lm_heads[0].weight.device) + + assert (self.truncated_vocab_size + == self.orig_vocab_size) or (self.token_map is not None) + + return loaded_params diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py new file mode 100644 index 0000000..48c7ce3 --- /dev/null +++ b/vllm/model_executor/models/minicpm.py @@ -0,0 +1,617 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM model compatible with HuggingFace weights.""" +import math +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class MiniCPMMoE(nn.Module): + """A tensor-parallel MoE implementation that shards each expert + across all ranks. 
+ + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class MiniCPMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + hidden_act_param: float, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act == "silu": + self.act_fn = SiluAndMul() + elif hidden_act == "fatrelu": + self.act_fn = FatreluAndMul(threshold=hidden_act_param) + else: + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu and fatrelu are supported for now.") + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiniCPMAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + # set rope as fp32 instead of bf16 + self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache( + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + orig_dtype = q.dtype + q, k = q.float(), k.float() + q, k = self.rotary_emb(positions, q, k) + q, k = q.to(orig_dtype), k.to(orig_dtype) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + self.hidden_size = config.hidden_size + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.prefix = prefix + self._init_attn_block() + self._init_ffn_block() + self.hidden_scale = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) + + def 
_init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.self_attn = MiniCPMAttention( + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=self.config.num_key_value_heads, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", + ) + + def _init_ffn_block(self): + self.post_attention_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.num_experts = getattr(self.config, "num_experts", 0) + if self.num_experts == 0: + self.mlp = MiniCPMMLP( + hidden_size=self.hidden_size, + intermediate_size=self.config.intermediate_size, + hidden_act=self.config.hidden_act, + hidden_act_param=getattr(self.config, "hidden_act_param", 0.), + quant_config=self.quant_config, + ) + else: + self.mlp = MiniCPMMoE( + num_experts=self.config.num_experts, + top_k=self.config.num_experts_per_tok, + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(residual, hidden_states, self.hidden_scale) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm(residual, hidden_states, self.hidden_scale) + + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class MiniCPMModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.num_experts = getattr(self.config, "num_experts", 0) + self._init_layers(prefix, config, cache_config, quant_config) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) + self.k = None + if self.layers[self.start_layer].self_attn.rotary_emb.__class__.__name__ == "Phi3LongRoPEScaledRotaryEmbedding": + self.k = self.layers[self.start_layer].self_attn.rotary_emb.original_max_position_embeddings + + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPMDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + + def get_input_embeddings(self, 
input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + return embedding * self.config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + if self.k is not None: + long_offset = torch.any(positions > self.k) + else: + long_offset = None + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(residual, hidden_states, self.layers[self.start_layer].hidden_scale) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. 
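+                    # Fall-through: weights that matched neither the stacked
+                    # projection mapping nor the MoE expert mapping are loaded
+                    # directly by name.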
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.prefix = prefix + self.vllm_config = vllm_config + self.config = config + self.lora_config = lora_config + self.cache_config = cache_config + self.quant_config = quant_config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): + return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + hidden_states = hidden_states / self.scale_width + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + result = loader.load_weights(weights) + for _, layer in self.model.named_modules(): + if layer.__class__.__name__ == "MiniCPM3Attention": + dtype = 
layer.kv_a_proj_with_mqa.weight.dtype + if dtype not in [torch.half, torch.bfloat16, torch.int8]: + assert False, print(f"run on dtype:{dtype} is not supported.") + layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) + del layer.kv_a_proj_with_mqa.weight + return result diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py new file mode 100644 index 0000000..69f6bea --- /dev/null +++ b/vllm/model_executor/models/minicpm3.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2024 The ModelBest team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" +from typing import Any, Dict, Optional + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, + MiniCPMForCausalLM, + MiniCPMModel) + +from .utils import make_layers +import ixformer.inference.functions as ops + + +class MiniCPM3Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + + tp_size = get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + 
self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config) + + self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, + self.kv_lora_rank + + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config) + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config) + + self.rotary_emb = get_rope( + self.qk_rope_head_dim, + rotary_dim=self.qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.k = self.rotary_emb.original_max_position_embeddings + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + long_offset: torch.Tensor, + ) -> torch.Tensor: + q_latent_kpe, _ = self.q_a_proj(hidden_states) + q, kv_a, k_pe = q_latent_kpe.split([self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + + q = self.q_a_layernorm(q) + q, _ = self.q_b_proj(q) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + kv_a = self.kv_a_layernorm(kv_a) + kv, _ = self.kv_b_proj(kv_a) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.empty_like(q) + v = torch.empty_like(q) + ops.mla_rope_phi(positions, q_pe, k_pe, k[...,self.qk_nope_head_dim:], self.rotary_emb.long_short_cos_sin_cache, long_offset, self.k) + ops.mla_copy_kv(k_nope, v_nope, k, v) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.self_attn = MiniCPM3Attention( + config=self.config, + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", + ) + + +class MiniCPM3Model(MiniCPMModel): + + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, 
+ cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPM3DecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + + +class MiniCPM3ForCausalLM(MiniCPMForCausalLM): + packed_modules_mapping = { + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): + return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py new file mode 100644 index 0000000..c74e086 --- /dev/null +++ b/vllm/model_executor/models/minicpmo.py @@ -0,0 +1,790 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence +from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, + Union) + +import torch +from torch import nn +from transformers import BatchFeature +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.whisper.modeling_whisper import ( + ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) + +from vllm.config import VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, + DictEmbeddingItems, ModalityData, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import PromptReplacement, PromptUpdate +from vllm.multimodal.profiling import ProcessorInputs + +from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, + MiniCPMVDummyInputsBuilder, + MiniCPMVMultiModalDataParser, + MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, + _minicpmv_field_config) +from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, + maybe_prefix) +from .vision import scatter_patch_features + +CPU_DEVICE = torch.device("cpu") + + +class MiniCPMOAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + audio_features: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_audios * num_slices, num_channels, length)` + Slice here means chunk. 
Audio that is too long will be split into slices, + which is the same as image. + Padding is used therefore `audio_features` is `torch.Tensor`. + """ + + audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_audios, num_slices)` + + This should be feature length of each audio slice, + which equals to `audio_features.shape[-1]` + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which audio embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_audios, num_embeds)` + """ + + +class MiniCPMOAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + audio_embeds: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_audios, num_slices, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + instead of a batched tensor. + Length of each slice may vary, so pass it as a list. + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which audio embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_audios, num_embeds)` + """ + + +MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, + MiniCPMOAudioEmbeddingInputs] + + +def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_features = hf_inputs.get("audio_features", torch.empty(0)) + num_audios = len(audio_features) + + return dict( + **_minicpmv_field_config(hf_inputs), + audio_features=MultiModalFieldConfig.batched("audio"), + audio_feature_lens=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.batched("audio"), + audio_embed_is_patch=MultiModalFieldConfig.batched("audio"), + audio_token_id=MultiModalFieldConfig.shared("audio", num_audios), + ) + + +class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems): + + def __init__( + self, + data: Mapping[str, torch.Tensor], + fields_factory: Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], + ], + ) -> None: + super().__init__( + data, + modality="image", + required_fields={"audio_embeds"}, + fields_factory=fields_factory, + ) + + +class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return MiniCPMOAudioEmbeddingItems( + data, + fields_factory=_minicpmo_field_config, + ) + + return super()._parse_audio_data(data) + + +class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): + audio_pattern = "()" + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {**super().get_supported_mm_limits(), "audio": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return { + **super().get_mm_max_tokens_per_item(seq_len, mm_counts), + "audio": + self.get_max_audio_tokens(), + } + + def get_audio_placeholder( + self, + audio_lens: int, + chunk_input: bool = True, + chunk_length: int = 1, + ) -> str: + hf_processor = self.get_hf_processor() + + return hf_processor.get_audio_placeholder( + audio_lens, + chunk_input=chunk_input, + chunk_length=chunk_length, + ) + + def get_default_audio_pool_step(self) -> int: + return 2 + + def get_default_audio_sampling_rate(self) -> int: + return 16000 + + def get_chunk_length(self) -> int: + return self.get_hf_config().audio_chunk_length + + def get_max_audio_tokens_per_chunk(self) -> int: + pool_step = 
self.get_default_audio_pool_step() + fbank_feat_in_chunk = 100 + cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 + num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1 + return num_audio_tokens + 2 # + + def get_max_audio_chunks_with_most_features(self) -> int: + return 30 + + def get_max_audio_tokens(self) -> int: + num_chunks = self.get_max_audio_chunks_with_most_features() + return self.get_max_audio_tokens_per_chunk() * num_chunks + + def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: + sampling_rate = self.get_default_audio_sampling_rate() + # exclude + num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2 + return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + max_audios = mm_counts.get("audio", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_audio_tokens = self.get_max_audio_tokens() * max_audios + max_total_frames = self.get_max_video_frames(seq_len - + max_image_tokens - + max_audio_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 1) + + +class MiniCPMODummyInputsBuilder( + MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): + + def get_dummy_processor_inputs( + self, seq_len: int, mm_counts: Mapping[str, + int]) -> ProcessorInputs: + num_audios = mm_counts.get("audio", 0) + audio_len = self.info.get_max_audio_chunks_with_most_features() * \ + self.info.get_default_audio_sampling_rate() + + processor_inputs = super().get_dummy_processor_inputs( + seq_len, mm_counts) + + audio_prompt_texts = self.info.audio_pattern * num_audios + audio_mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text=processor_inputs.prompt_text + audio_prompt_texts, + mm_data={ + **processor_inputs.mm_data, + **audio_mm_data, + }, + ) + + +class MiniCPMOMultiModalProcessor( + MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + return MiniCPMOMultiModalDataParser( + target_sr=self.info.get_default_audio_sampling_rate()) + + def get_audio_prompt_texts( + self, + audio_lens: int, + chunk_input: bool = True, + chunk_length: int = 1, + ) -> str: + return self.info.get_audio_placeholder( + audio_lens, + chunk_input=chunk_input, + chunk_length=chunk_length, + ) + + def process_audios( + self, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + if (audios := mm_data.get("audios")) is None: + return {} + + parsed_audios = (self._get_data_parser().parse_mm_data({ + "audio": audios + }).get_items("audio", + (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))) + + if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): + audio_inputs = {} + + audio_lens = [ + self.info.get_audio_len_by_num_chunks( + sum(map(len, + parsed_audios.get(i)["audio_embeds"]))) + for i in range(len(parsed_audios)) + ] + else: + audio_inputs = self._base_call_hf_processor( + prompts=[self.info.audio_pattern] * len(parsed_audios), + mm_data={"audios": [[audio] for audio in parsed_audios]}, + mm_kwargs={ + **mm_kwargs, + "chunk_input": True, + }, + out_keys={"audio_features", "audio_feature_lens"}, + ) + + # Avoid padding since we need the output for each audio to be + # independent of other 
audios for the cache to work correctly + unpadded_audio_features = [ + feat[:, :feature_len] for feat, feature_len in zip( + audio_inputs["audio_features"], + audio_inputs["audio_feature_lens"], + ) + ] + audio_inputs["audio_features"] = unpadded_audio_features + + audio_lens = [ + parsed_audios.get_audio_length(i) + for i in range(len(parsed_audios)) + ] + + audio_repl_features = [ + self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens + ] + + tokenizer = self.info.get_tokenizer() + audio_repls_feature_tokens = [ + tokenizer.encode(audio_repl, add_special_tokens=False) + for audio_repl in audio_repl_features + ] + + embed_is_patch = [ + self.get_embed_is_patch(audio_repl_tokens) + for audio_repl_tokens in audio_repls_feature_tokens + ] + audio_inputs["audio_embed_is_patch"] = embed_is_patch + + unk_token_id = tokenizer.get_vocab()[""] + audio_inputs["audio_token_id"] = torch.tensor(unk_token_id) + + return audio_inputs + + def process_mm_inputs( + self, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + return { + **super().process_mm_inputs(mm_data, mm_kwargs), + **self.process_audios(mm_data, mm_kwargs), + } + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + base_updates = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + audio_placeholder = self.info.audio_pattern + + def get_audio_replacement(item_idx: int): + audios = mm_items.get_items( + "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)) + + if isinstance(audios, MiniCPMOAudioEmbeddingItems): + single_audio_embeds = audios.get(item_idx)["audio_embeds"] + audio_len = self.info.get_audio_len_by_num_chunks( + sum(map(len, single_audio_embeds))) + else: + audio_len = audios.get_audio_length(item_idx) + + return self.get_audio_prompt_texts(audio_len) + + return [ + *base_updates, + PromptReplacement(modality="audio", + target=audio_placeholder, + replacement=get_audio_replacement), + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _minicpmo_field_config(hf_inputs) + + +class MultiModalProjector(nn.Module): + + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.linear1 = nn.Linear(in_features=in_dim, + out_features=out_dim, + bias=True) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(in_features=out_dim, + out_features=out_dim, + bias=True) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.relu(self.linear1(audio_features)) + hidden_states = self.linear2(hidden_states) + return hidden_states + + +class MiniCPMWhisperEncoderLayer(nn.Module): + + def __init__(self, config: WhisperConfig, layer_idx: int): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WHISPER_ATTENTION_CLASSES[ + config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = 
nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + past_key_values = None + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, past_key_values = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_values, + ) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + hidden_states = cast_overflow_tensors(hidden_states) + + outputs = (hidden_states, ) + + return outputs + + +class MiniCPMWhisperEncoder(WhisperEncoder): + + def __init__(self, config: WhisperConfig): + super().__init__(config) + self.layers = nn.ModuleList([ + MiniCPMWhisperEncoderLayer(config, layer_idx=i) + for i in range(config.encoder_layers) + ]) + + def forward( + self, + input_features: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> BaseModelOutputWithPast: + # Ignore copy + input_features = input_features.to(dtype=self.conv1.weight.dtype, + device=self.conv1.weight.device) + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + + embed_pos = self.embed_positions.weight + + embed_pos = embed_pos[:inputs_embeds.shape[1], :] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + encoder_states = () + + for idx, encoder_layer in enumerate(self.layers): + encoder_states = encoder_states + (hidden_states, ) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + # Ignore copy + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.layer_norm(hidden_states) + encoder_states = encoder_states + (hidden_states, ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + MiniCPMOMultiModalProcessor, + info=MiniCPMOProcessingInfo, + dummy_inputs=MiniCPMODummyInputsBuilder) +class MiniCPMO(MiniCPMV2_6): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.apm = self.init_audio_module(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "apm")) + + self.audio_token_id = None + + def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Do not use parameters temporarily + audio_config = 
self.config.audio_config + model = MiniCPMWhisperEncoder(audio_config) + audio_output_dim = int(audio_config.encoder_ffn_dim // 4) + self.audio_avg_pooler = \ + nn.AvgPool1d(self.config.audio_pool_step, + stride=self.config.audio_pool_step) + self.audio_projection_layer = \ + MultiModalProjector(in_dim=audio_output_dim,out_dim=self.embed_dim) + self.audio_encoder_layer = -1 + return model + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) + return loader.load_weights(weights) + + def subsequent_chunk_mask( + self, + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = CPU_DEVICE, + num_lookhead: int = 0, + ) -> torch.Tensor: + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, + 0) + ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, + size) + ret[i, start:ending] = True + return ret + + def _get_feat_extract_output_lengths(self, + input_lengths: torch.LongTensor): + input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 + input_lengths_after_pooling = ( + input_lengths_after_cnn - + self.config.audio_pool_step) // self.config.audio_pool_step + 1 + input_lengths_after_pooling = input_lengths_after_pooling.to( + dtype=torch.int32) + + return input_lengths_after_cnn, input_lengths_after_pooling + + def get_audio_hidden_states( + self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]: + chunk_length = self.config.audio_chunk_length + + # (bs, 80, frames) or [], multi audios need filled in advance + wavforms_raw = data["audio_features"] + if isinstance(wavforms_raw, list): + B = len(wavforms_raw) + C = wavforms_raw[0].shape[-2] + L = max(item.shape[-1] for item in wavforms_raw) + device = wavforms_raw[0].device + dtype = wavforms_raw[0].dtype + + wavforms = torch.zeros((B, C, L), dtype=dtype, device=device) + for i, wavforms_item in enumerate(wavforms_raw): + L_item = wavforms_item.shape[-1] + wavforms[i, ..., :L_item] = wavforms_item + else: + wavforms = wavforms_raw + + # list, [[x1, x2], [y1], [z1]] + audio_feature_lens_raw = data["audio_feature_lens"] + if isinstance(audio_feature_lens_raw, torch.Tensor): + audio_feature_lens_raw = audio_feature_lens_raw.unbind(0) + + audio_feature_lens = torch.hstack(audio_feature_lens_raw) + batch_size, _, max_mel_seq_len = wavforms.shape + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = (torch.arange( + 0, + max_seq_len, + dtype=audio_feature_lens.dtype, + device=audio_feature_lens.device).unsqueeze(0).expand( + batch_size, max_seq_len)) + lengths_expand = audio_feature_lens.unsqueeze(1).expand( + batch_size, max_seq_len) + # Create mask + padding_mask = seq_range >= lengths_expand # 1 for padded values + + audio_attention_mask_ = padding_mask.view( + batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len, + max_seq_len) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.apm.conv1.weight.dtype, + device=self.apm.conv1.weight.device) + + if chunk_length > 0: + chunk_num_frame = int(chunk_length * 50) + chunk_mask = self.subsequent_chunk_mask( + size=max_seq_len, + chunk_size=chunk_num_frame, + num_left_chunks=-1, + device=audio_attention_mask_.device, + ) + audio_attention_mask_ = torch.logical_or( + audio_attention_mask_, torch.logical_not(chunk_mask)) + + 
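+ # Illustrative note (added for clarity, not part of the original code):
+ # with chunk_length > 0 the Whisper encoder runs block-wise "chunked"
+ # self-attention. For example, subsequent_chunk_mask(size=4, chunk_size=2)
+ # returns
+ #   [[1, 1, 0, 0],
+ #    [1, 1, 0, 0],
+ #    [1, 1, 1, 1],
+ #    [1, 1, 1, 1]]
+ # where True marks frames a position may attend to; its logical_not is
+ # OR-ed into the padding mask above, and the combined mask is filled with
+ # -inf below, so positions cannot attend to later chunks (earlier chunks
+ # stay visible because num_left_chunks=-1).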
audio_attention_mask[audio_attention_mask_] = float("-inf") + audio_states = self.apm( + wavforms, attention_mask=audio_attention_mask).hidden_states[ + self.audio_encoder_layer] + audio_embeds = self.audio_projection_layer(audio_states) + + audio_embeds = audio_embeds.transpose(1, 2) + audio_embeds = self.audio_avg_pooler(audio_embeds) + audio_embeds = audio_embeds.transpose(1, 2) + + _, feature_lens_after_pooling = \ + self._get_feat_extract_output_lengths(audio_feature_lens) + + num_audio_tokens = feature_lens_after_pooling + + final_audio_embeds = list[torch.Tensor]() + idx = 0 + for i in range(len(audio_feature_lens_raw)): + target_audio_embeds_lst = list[torch.Tensor]() + for _ in range(len(audio_feature_lens_raw[i])): + target_audio_embeds_lst.append( + audio_embeds[idx, :num_audio_tokens[idx], :]) + idx += 1 + + final_audio_embeds.append(torch.cat(target_audio_embeds_lst)) + + return final_audio_embeds + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]: + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + audio_token_id = kwargs.pop("audio_token_id") + if audio_token_id is not None: + assert isinstance(audio_token_id, torch.Tensor) + self.mm_token_ids.add(audio_token_id.flatten().unique().item()) + + audio_embed_is_patch = kwargs.pop("audio_embed_is_patch") + if not isinstance(audio_embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_embed_is_patch. " + f"Got type: {type(audio_embed_is_patch)}") + + audio_embed_is_patch = flatten_bn(audio_embed_is_patch) + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_embeds. " + f"Got type: {type(audio_embeds)}") + + audio_embeds_flat = flatten_bn(audio_embeds) + + return MiniCPMOAudioEmbeddingInputs( + type="audio_embeds", + audio_embeds=audio_embeds_flat, + embed_is_patch=audio_embed_is_patch, + ) + + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_features. " + f"Got type: {type(audio_features)}") + + audio_feature_lens = kwargs.pop("audio_feature_lens") + if not isinstance(audio_feature_lens, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_feature_lens. " + f"Got type: {type(audio_feature_lens)}") + + audio_features_flat = flatten_bn(audio_features) + audio_feature_lens_flat = flatten_bn(audio_feature_lens) + + return MiniCPMOAudioFeatureInputs( + type="audio_features", + audio_features=audio_features_flat, + audio_feature_lens=audio_feature_lens_flat, + embed_is_patch=audio_embed_is_patch, + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = super()._parse_and_validate_multimodal_inputs(**kwargs) + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
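+ # Note (added for clarity): the insertion order of `modalities` matters
+ # because _process_multimodal_inputs() iterates its keys when assembling
+ # the final embeddings; audio entries are added here after any image/video
+ # entries already handled by the parent class.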
+ for input_key in kwargs: + if input_key in ("audio_features", + "audio_embeds") and "audios" not in modalities: + modalities["audios"] = self._parse_and_validate_audio_input( + **kwargs) + + return modalities + + def _process_audio_input( + self, + audio_input: MiniCPMOAudioInputs, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + if audio_input["type"] == "audio_embeds": + return audio_input["audio_embeds"] + + return self.get_audio_hidden_states(audio_input) + + def _process_multimodal_inputs(self, modalities: dict): + multimodal_embeddings = super()._process_multimodal_inputs(modalities) + + for modality in modalities: + if modality == "audios": + audio_input = modalities["audios"] + audio_features = self._process_audio_input(audio_input) + multimodal_embeddings += tuple( + scatter_patch_features( + audio_features, + audio_input["embed_is_patch"], + )) + + return multimodal_embeddings diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py new file mode 100644 index 0000000..5fab9df --- /dev/null +++ b/vllm/model_executor/models/minicpmv.py @@ -0,0 +1,1369 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" +import math +from collections import defaultdict +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property, partial +from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, + Union) + +import numpy as np +import torch +import torch.types +from torch import nn +from transformers import BatchFeature, PretrainedConfig +from typing_extensions import TypeVar + +from vllm.config import VllmConfig +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, + get_2d_sincos_pos_embed) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.minicpm import MiniCPMForCausalLM +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, + ImageProcessorItems, ImageSize, + ModalityData, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser, + VideoItem, VideoProcessorItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.utils import flatten_2d_lists + +from .idefics2_vision_model import Idefics2VisionTransformer +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features + +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 + + +class MiniCPMVImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: list[torch.Tensor] + """ + Shape: `(batch_size * num_images * num_slices, num_channels, height, width)` + + Note that the image size may vary, so we pass it as a list + instead of a batched tensor. + """ + + tgt_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images * num_slices, 2)` + + This should be in `(height, width)` format. + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_images, num_embeds)` + """ + + num_slices: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class MiniCPMVImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + image_embeds: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_slices, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + instead of a batched tensor. + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. 
+ + Shape: `(batch_size * num_images, num_embeds)` + """ + + +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, + MiniCPMVImageEmbeddingInputs] + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +class Resampler2_5(BaseResampler): + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + quant_config=quant_config, + prefix=prefix) + + self.max_size = max_size + self._set_2d_pos_cache(self.max_size) + + def _set_2d_pos_cache(self, + max_size: Tuple[int, int], + device: torch.types.Device = "cpu") -> None: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + max_size, + version=(2, 5)) + pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) + self.register_buffer("pos_embed", pos_embed, persistent=False) + + def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, + device: torch.types.Device) -> None: + max_h = tgt_sizes[:, 0].max().item() + max_w = tgt_sizes[:, 1].max().item() + assert isinstance(max_h, int) and isinstance(max_w, int) + + if max_h > self.max_size[0] or max_w > self.max_size[1]: + self.max_size = ( + max(max_h, self.max_size[0]), + max(max_w, self.max_size[1]), + ) + self._set_2d_pos_cache(self.max_size, device) + + def forward(self, x: torch.Tensor, + tgt_sizes: torch.Tensor) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros((bs, max_patch_len), + dtype=torch.bool, + device=device) + + pos_embed = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i].tolist() + pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( + (tgt_h * tgt_w, -1)).to(dtype)) # patches * D + key_padding_mask[i, patch_len[i]:] = True + pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, + batch_first=True, + padding_value=0.0).permute( + 1, 0, + 2) # BLD => L * B * D + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + + q = self.ln_q(self.query) # Q * D + + out = self.attn( + self._repeat(q, bs), # Q * B * D + x + pos_embed, # L * B * D + L * B * D + x, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + +def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: + version_float = getattr(config, "version", None) + + # The old configs do not include version number + # TODO: Remove this after the HF repos are updated + if version_float is None: + if config.hidden_size == 2304 and config.query_num == 64: + return (2, 0) + return (2, 5) + version_str = str(version_float) + return tuple(int(x) for x in version_str.split(".")) + + +def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): + pixel_values = hf_inputs.get("pixel_values", torch.empty(0)) + num_images = len(pixel_values) + + video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0)) + num_videos = len(video_pixel_values) + + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + 
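+ # Note (added for clarity): the *_token_id entries at the end of this dict
+ # use MultiModalFieldConfig.shared(...), i.e. one value reused for every
+ # item of that modality, while the per-item tensors use .batched(...) and
+ # contribute one element per image/video.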
tgt_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_image_sizes=MultiModalFieldConfig.batched("video"), + video_tgt_sizes=MultiModalFieldConfig.batched("video"), + video_embeds=MultiModalFieldConfig.batched("video"), + video_embed_is_patch=MultiModalFieldConfig.batched("video"), + image_token_id=MultiModalFieldConfig.shared("image", num_images), + video_token_id=MultiModalFieldConfig.shared("video", num_videos), + ) + + +class MiniCPMVImageEmbeddingItems(DictEmbeddingItems): + + def __init__( + self, + data: Mapping[str, torch.Tensor], + fields_factory: Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], + ], + ) -> None: + super().__init__( + data, + modality="image", + required_fields={"image_embeds", "image_sizes"}, + fields_factory=fields_factory, + ) + + def get_image_size(self, index: int) -> ImageSize: + image_size = self.get(index)["image_sizes"].tolist() + return ImageSize(width=image_size[0], height=image_size[1]) + + +class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems): + + def __init__( + self, + data: Mapping[str, torch.Tensor], + fields_factory: Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], + ], + ) -> None: + super().__init__( + data, + modality="video", + required_fields={"video_embeds", "video_image_sizes"}, + fields_factory=fields_factory, + ) + + def get_frame_size(self, index: int) -> ImageSize: + frame_size = self.get(index)["video_image_sizes"].tolist() + return ImageSize(width=frame_size[0], height=frame_size[1]) + + def get_num_frames(self, index: int) -> int: + return len(self.get(index)["video_image_sizes"]) + + +class MiniCPMVMultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return MiniCPMVImageEmbeddingItems( + data, + fields_factory=_minicpmv_field_config, + ) + + return super()._parse_image_data(data) + + def _parse_video_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return MiniCPMVVideoEmbeddingItems( + data, + fields_factory=_minicpmv_field_config, + ) + + return super()._parse_video_data(data) + + +class MiniCPMVProcessingInfo(BaseProcessingInfo): + image_pattern = "(./)" + video_pattern = "()" + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + hf_processor = self.ctx.get_hf_processor(**kwargs) + + # NumPy arrays are considered as Iterable but not Sequence in + # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428 + image_processor = hf_processor.image_processor # type: ignore + for attr in ("mean", "std"): + val = getattr(image_processor, attr) + if isinstance(val, np.ndarray): + setattr(image_processor, attr, val.tolist()) + + return hf_processor + + def get_image_processor(self): + hf_processor = self.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + return image_processor + + def get_model_version(self): + return get_version_by_config(self.get_hf_config()) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + mm_limits = {"image": None} + if self.get_model_version() == (2, 6): + mm_limits["video"] = 
None + + return mm_limits + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + mm_max_tokens = {"image": self.get_max_image_tokens()} + if self.get_model_version() == (2, 6): + mm_max_tokens["video"] = self.get_max_video_tokens( + seq_len, mm_counts) + + return mm_max_tokens + + def get_slice_image_placeholder( + self, + image_size: ImageSize, + # For MiniCPM V/O 2.6 + image_idx: int = 0, + max_slice_nums: Optional[int] = None, + use_image_id: bool = True, + ) -> str: + image_processor = self.get_image_processor() + version = self.get_model_version() + + if version == (2, 0) or version == (2, 5): + return image_processor.get_slice_image_placeholder(image_size) + + return image_processor.get_slice_image_placeholder( + image_size, + image_idx=image_idx, + max_slice_nums=max_slice_nums, + use_image_id=use_image_id, + ) + + def get_num_image_tokens( + self, + image_size: ImageSize, + max_slice_nums: Optional[int] = None, + use_image_id: bool = True, + ) -> int: + tokenizer = self.get_tokenizer() + image_placeholders = self.get_slice_image_placeholder( + image_size, + max_slice_nums=max_slice_nums, + use_image_id=use_image_id, + ) + image_token_ids = tokenizer.encode(image_placeholders, + add_special_tokens=False) + + return len(image_token_ids) + + def get_max_image_tokens(self) -> int: + image_size = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_size) + + def get_image_max_slice_num(self) -> int: + return getattr(self.get_hf_config(), "max_slice_num", 9) + + def get_image_size_with_most_features(self) -> ImageSize: + image_size = getattr(self.get_hf_config(), "image_size", 448) + max_slice_num = self.get_image_max_slice_num() + return ImageSize(width=image_size, height=image_size * max_slice_num) + + def get_max_video_frame_tokens(self) -> int: + frame_size = self.get_video_frame_size_with_most_features() + + return self.get_num_image_tokens( + frame_size, + max_slice_nums=self.get_video_max_slice_num(), + use_image_id=False, + ) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + num_frames = self.get_num_frames_with_most_features(seq_len, mm_counts) + num_video_tokens_total = self.get_max_video_frame_tokens() * num_frames + return num_video_tokens_total + + def get_video_max_slice_num(self) -> int: + return 1 + + def get_video_frame_size_with_most_features(self) -> ImageSize: + image_size = getattr(self.get_hf_config(), "image_size", 448) + max_slice_num = self.get_video_max_slice_num() + return ImageSize(width=image_size, height=image_size * max_slice_num) + + def get_max_video_frames(self, max_tokens: int) -> int: + num_frame_tokens = self.get_max_video_frame_tokens() + num_frames = max_tokens // num_frame_tokens + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self.get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 1) + + +_I = TypeVar("_I", + bound=MiniCPMVProcessingInfo, + default=MiniCPMVProcessingInfo) + + +class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) 
-> ProcessorInputs: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_width, image_height = \ + self.info.get_image_size_with_most_features() + video_width, video_height = \ + self.info.get_video_frame_size_with_most_features() + num_video_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + mm_data = { + "image": + self._get_dummy_images(width=image_width, + height=image_height, + num_images=num_images), + "video": [ + self._get_dummy_images(width=video_width, + height=video_height, + num_images=num_video_frames) + ] * num_videos, + } + + image_prompt_texts = self.info.image_pattern * num_images + video_prompt_texts = self.info.video_pattern * num_videos + + return ProcessorInputs(prompt_text=image_prompt_texts + + video_prompt_texts, + mm_data=mm_data) + + +class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): + + def _get_data_parser(self) -> MultiModalDataParser: + return MiniCPMVMultiModalDataParser() + + def get_image_prompt_texts(self, + image_size: ImageSize, + image_idx: int = 0) -> str: + return self.info.get_slice_image_placeholder( + image_size, + image_idx=image_idx, + ) + + def get_video_prompt_texts(self, image_size: ImageSize, + num_frames: int) -> str: + return self.info.get_slice_image_placeholder( + image_size=image_size, + image_idx=0, + max_slice_nums=self.info.get_video_max_slice_num(), + use_image_id=False, + ) * num_frames + + def get_embed_is_patch( + self, + input_ids: list[int], + ) -> torch.Tensor: + tokenizer = self.info.get_tokenizer() + unk_token_id = tokenizer.get_vocab()[""] + return torch.tensor(input_ids) == unk_token_id + + def process_images( + self, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + if (images := mm_data.get("images")) is None: + return {} + + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": images + }).get_items("image", + (MiniCPMVImageEmbeddingItems, ImageProcessorItems))) + + if isinstance(parsed_images, MiniCPMVImageEmbeddingItems): + image_inputs = {} + else: + image_inputs = self._base_call_hf_processor( + prompts=[self.info.image_pattern] * len(parsed_images), + mm_data={"images": [[image] for image in parsed_images]}, + mm_kwargs=mm_kwargs, + out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, + ) + + image_sizes = [ + parsed_images.get_image_size(i) for i in range(len(parsed_images)) + ] + image_repl_features = [ + self.get_image_prompt_texts(size, idx) + for idx, size in enumerate(image_sizes) + ] + + tokenizer = self.info.get_tokenizer() + image_repls_feature_tokens = [ + tokenizer.encode(image_repl, add_special_tokens=False) + for image_repl in image_repl_features + ] + + embed_is_patch = [ + self.get_embed_is_patch(image_repl_tokens) + for image_repl_tokens in image_repls_feature_tokens + ] + image_inputs["embed_is_patch"] = embed_is_patch + + unk_token_id = tokenizer.get_vocab()[""] + image_inputs["image_token_id"] = torch.tensor(unk_token_id) + + return image_inputs + + def process_videos( + self, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + if (videos := mm_data.get("videos")) is None: + return {} + + parsed_videos = (self._get_data_parser().parse_mm_data({ + "video": videos + }).get_items("video", + (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))) + + if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems): + video_inputs = {} + else: + video_inputs = self._base_call_hf_processor( + prompts=[ + 
self.info.image_pattern * len(video) + for video in parsed_videos + ], + mm_data={"images": list(parsed_videos)}, + mm_kwargs={ + **mm_kwargs, + "max_slice_nums": + self.info.get_video_max_slice_num(), + }, + out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, + ) + + frame_sizes = [ + parsed_videos.get_frame_size(i) for i in range(len(parsed_videos)) + ] + num_frames = [ + parsed_videos.get_num_frames(i) for i in range(len(parsed_videos)) + ] + video_repl_features = [ + self.get_video_prompt_texts(size, nframes) + for size, nframes in zip(frame_sizes, num_frames) + ] + + tokenizer = self.info.get_tokenizer() + video_repls_feature_tokens = [ + tokenizer.encode(video_repl, add_special_tokens=False) + for video_repl in video_repl_features + ] + + embed_is_patch = [ + self.get_embed_is_patch(video_repl_tokens) + for video_repl_tokens in video_repls_feature_tokens + ] + video_inputs["embed_is_patch"] = embed_is_patch + + video_inputs = {f"video_{k}": v for k, v in video_inputs.items()} + + unk_token_id = tokenizer.get_vocab()[""] + video_inputs["video_token_id"] = torch.tensor(unk_token_id) + + return video_inputs + + def process_mm_inputs( + self, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + return { + **self.process_images(mm_data, mm_kwargs), + **self.process_videos(mm_data, mm_kwargs), + } + + def _base_call_hf_processor( + self, + prompts: list[str], + mm_data: Mapping[str, Sequence[object]], + mm_kwargs: Mapping[str, object], + *, + out_keys: set[str], + ) -> dict[str, NestedTensors]: + # This processor supports zipping prompt and mm_data together + if self.info.get_model_version() == (2, 6): + inputs = super()._call_hf_processor( + prompt=prompts, # type: ignore + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + else: + inputs = defaultdict[str, list[torch.Tensor]](list) + + for i, prompt in enumerate(prompts): + inputs_one = super()._call_hf_processor( + prompt=prompt, + mm_data={ + k: v[i] + for k, v in mm_data.items() + }, + mm_kwargs=mm_kwargs, + ) + + for k, v in inputs_one.items(): + assert len(v) == 1, (k, len(v)) + inputs[k].append(v[0]) + + return {k: inputs[k] for k in out_keys} + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + + input_ids = torch.tensor([tokenizer.encode(prompt)]) + mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) + + return BatchFeature({ + "input_ids": input_ids, + **mm_inputs, + }) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + placeholder = { + "image": self.info.image_pattern, + "video": self.info.video_pattern, + } + + def get_image_replacement(item_idx: int): + images = mm_items.get_items( + "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems)) + + image_size = images.get_image_size(item_idx) + + return self.get_image_prompt_texts(image_size, item_idx) + + def get_video_replacement(item_idx: int): + videos = mm_items.get_items( + "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)) + + frame_size = videos.get_frame_size(item_idx) + num_frames = videos.get_num_frames(item_idx) + + return self.get_video_prompt_texts(frame_size, 
num_frames) + + get_replacement = { + "image": get_image_replacement, + "video": get_video_replacement, + } + + return [ + PromptReplacement(modality=modality, + target=placeholder[modality], + replacement=get_replacement[modality]) + for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _minicpmv_field_config(hf_inputs) + + +class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): + """ + The abstract class of MiniCPMV can only be inherited, but cannot be + instantiated. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + quant_config = vllm_config.quant_config + super().__init__() + # All MiniCPM-V models disable `tie_word_embeddings` but + # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot + # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # and config class + self.config = config + self.multimodal_config = multimodal_config + + self.version = get_version_by_config(self.config) + self.llm = self.init_llm(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "llm")) + self.vpm = self.init_vision_module(config, + quant_config, + prefix=maybe_prefix(prefix, "vpm")) + self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else + self.vpm.embeddings.embed_dim) + self.embed_dim = self.config.hidden_size + + self.resampler = self.init_resampler(self.embed_dim, + self.vision_dim, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "resampler")) + + self.mm_token_ids = set[int]() + self.make_empty_intermediate_tensors = ( + self.llm.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.llm, "sampler"): + return self.llm.sampler + + return get_sampler() + + def _parse_and_validate_vision_input( + self, + modality: str, + **kwargs: object, + ) -> Optional[MiniCPMVImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + image_token_id = kwargs.pop("image_token_id") + if image_token_id is not None: + assert isinstance(image_token_id, torch.Tensor) + self.mm_token_ids.add(image_token_id.flatten().unique().item()) + + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of embed_is_patch for {modality=}. " + f"Got type: {type(embed_is_patch)}") + + embed_is_patch = flatten_bn(embed_is_patch) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of image_embeds for {modality=}. " + f"Got type: {type(image_embeds)}") + + image_embeds_flat = flatten_bn(image_embeds) + + return MiniCPMVImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds_flat, + embed_is_patch=embed_is_patch, + ) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of pixel_values for {modality=}. " + f"Got type: {type(pixel_values)}") + + tgt_sizes = kwargs.pop("tgt_sizes") + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. 
" + f"Got type: {type(tgt_sizes)}") + + num_slices = [[len(p) for p in ps] for ps in pixel_values] + num_slices_flat = flatten_bn(torch.tensor(num_slices)) + + pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values)) + tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True) + + if len(pixel_values_flat) != len(tgt_sizes_flat): + raise ValueError("Inconsistent flattened lengths, found: " + f"{len(pixel_values_flat)} vs. " + f"{len(tgt_sizes_flat)}") + + return MiniCPMVImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values_flat, + tgt_sizes=tgt_sizes_flat, + embed_is_patch=embed_is_patch, + num_slices=num_slices_flat, + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_vision_input( + "images", **kwargs) + if input_key in ("video_pixel_values", + "video_embeds") and "videos" not in modalities: + + def _image_key(video_key: str): + if video_key == "video_token_id": + return "image_token_id" + + return video_key.removeprefix("video_") + + modalities["videos"] = self._parse_and_validate_vision_input( + "videos", **{ + _image_key(k): v + for k, v in kwargs.items() + }) + + return modalities + + def _process_vision_input( + self, + image_input: MiniCPMVImageInputs, + ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"] + + image_features_flat = self.get_vision_hidden_states(image_input) + + num_slices = image_input["num_slices"] + return [ + e.flatten(0, 1) + for e in image_features_flat.split(num_slices.tolist()) + ] + + def _process_multimodal_inputs(self, modalities: dict): + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + image_features = self._process_vision_input(image_input) + multimodal_embeddings += tuple( + scatter_patch_features( + image_features, + image_input["embed_is_patch"], + )) + if modality == "videos": + video_input = modalities["videos"] + video_features = self._process_vision_input(video_input) + multimodal_embeddings += tuple( + scatter_patch_features( + video_features, + video_input["embed_is_patch"], + )) + + return multimodal_embeddings + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + return self._process_multimodal_inputs(modalities) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert len(self.mm_token_ids) > 0 + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + list(self.mm_token_ids), + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.llm.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.llm.compute_logits(hidden_states, sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field(language_model="llm", + connector="resampler", + tower_model="vpm") + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + raise NotImplementedError + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> nn.Module: + raise NotImplementedError + + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + raise NotImplementedError + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + raise NotImplementedError + + +class MiniCPMV2_0(MiniCPMVBaseModel): + + def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (2, 0) + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> nn.Module: + # TODO: refactor vision model through timm wrapper from transformers + try: + import timm + except ImportError: + raise ImportError("Please install timm==0.9.10") from ImportError + + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + model = model.to(dtype=torch.get_default_dtype()) + + if (isinstance(model, timm.models.VisionTransformer) + and model.attn_pool is not None): + model.attn_pool = torch.nn.Identity() + + if self.config.drop_vision_last_layer: + model.blocks = model.blocks[:-1] + + return model + + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2(embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int( + math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=False, + do_post_projection=True, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + + P_h, P_w = self.vpm.patch_embed.patch_size + dtype: torch.dtype = self.vpm.pos_embed.data.dtype + num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0) + + res = list[torch.Tensor]() + for pixel_value in pixel_values: + H, W = pixel_value[0].shape[-2:] + tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w)) + vision_embedding = self.vpm.forward_features( + pixel_value.unsqueeze(0).type(dtype)) + + if num_prefix_tokens > 0: + vision_embedding = vision_embedding[:, num_prefix_tokens:] + res.append(self.resampler(vision_embedding, tgt_size)) + + return torch.vstack(res) + + +class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (2, 5) + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> nn.Module: + model = Idefics2VisionTransformer(config.vision_config, + quant_config=quant_config, + prefix=prefix) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + 
kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=None, + ) + + return self.resampler(vision_embedding, tgt_sizes) + + +class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (2, 6) + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> nn.Module: + model = Idefics2VisionTransformer(config.vision_config, + quant_config=quant_config, + prefix=prefix) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + with set_default_torch_dtype(torch.float16): + # The resampler in 2.6 remains consistent with the one in 2.5. 
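+ # Note (added for clarity): Resampler2_5 compresses the variable number
+ # of ViT patch embeddings of each image slice into a fixed set of
+ # `query_num` tokens via cross-attention, using 2D sin/cos position
+ # embeddings sized to each slice's (height, width) patch grid.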
+ resampler = Resampler2_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes) + + +_SUPPORT_VERSION = { + (2, 0): MiniCPMV2_0, + (2, 5): MiniCPMV2_5, + (2, 6): MiniCPMV2_6, +} + + +@MULTIMODAL_REGISTRY.register_processor( + MiniCPMVMultiModalProcessor, + info=MiniCPMVProcessingInfo, + dummy_inputs=MiniCPMVDummyInputsBuilder) +class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): + """ + Different versions of MiniCPMV use different visual encoders and LLMs, + which is not conducive to the current integration logic of LoRA and + bitsandbytes in vLLM. Therefore, it is necessary to separate them. 
+ """ + + def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + if not hasattr(config, "version"): + if config.hidden_size == 2304 and config.query_num == 64: + version = (2, 0) + else: + version = (2, 5) + else: + version = str(config.version).split(".") + version = tuple([int(x) for x in version]) + # Dispatch class based on version + instance_cls = _SUPPORT_VERSION.get(version) + if instance_cls is None: + raise ValueError( + "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") + + # quant_config references base class members, + # so update values before init is called + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + cls.embedding_modules.update(instance_cls.embedding_modules) + cls.embedding_padding_modules += instance_cls.embedding_padding_modules + return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minimax_cache.py b/vllm/model_executor/models/minimax_cache.py new file mode 100644 index 0000000..c95cbb4 --- /dev/null +++ b/vllm/model_executor/models/minimax_cache.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import torch + +from vllm.model_executor.models.constant_size_cache import ConstantSizeCache + + +@dataclass +class MinimaxCacheParams: + minimax_cache: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self, layer_idx): + return MinimaxCacheParams(self.minimax_cache[layer_idx, ...], + self.state_indices_tensor) + + +class MinimaxCacheManager(ConstantSizeCache): + + def __init__(self, dtype, cache_shape): + super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1] + self._minimax_cache = torch.empty(size=cache_shape, + dtype=dtype, + device="cuda") + + @property + def cache(self): + return self._minimax_cache + + def _copy_cache(self, from_index: int, to_index: int): + assert len(self.cache) > 0 + for cache_t in self.cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py new file mode 100644 index 0000000..7562aa6 --- /dev/null +++ b/vllm/model_executor/models/minimax_text_01.py @@ -0,0 +1,1273 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only MiniMaxText01 model.""" +import copy +import math +import re +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.lightning_attn import ( + lightning_attention, linear_decode_forward_triton) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + 
ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import HasInnerState, IsHybrid, SupportsV0Only +from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +def replace_weight_name(name: str, + key: str = None, + to: str = None, + count: int = None, + prefix: str = None) -> str: + name = name.replace(key, to) if count is None else \ + name.replace(key, to, count) + return name + + +def weight_loader_with_alias(alias: str): + + def wrapper(func: callable): + + def inner_func(param: torch.Tensor, + loaded_weight: torch.Tensor, + *args, + prefix: str = None, + **kwargs): + value = func(param, loaded_weight, *args, **kwargs) + return value + + return inner_func + + return wrapper + + +class MiniMaxText01RMSNormTP(CustomOp): + name = "MiniMaxText01RMSNormTP" + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.tp_world = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.weight = nn.Parameter(torch.ones(int(hidden_size / + self.tp_world))) + + self.weight.weight_loader = self.weight_loader + self.variance_epsilon = eps + return + + @staticmethod + def weight_loader( + param: nn.Parameter, + loaded_weight: torch.Tensor, + ) -> None: + tp_world = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + shard_size = loaded_weight.shape[0] // tp_world + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + param.data.copy_(loaded_weight[shard]) + return + + def _forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) + if self.tp_world > 1: + variance = tensor_model_parallel_all_reduce( + variance) / self.tp_world + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + return x + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert residual is None, "RMSNorm does not support residual connection." 
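+        # The weight (and input) are sharded along the hidden dimension, so
+        # _forward all-reduces the per-shard variance across tensor-parallel
+        # ranks; the result matches an unsharded RMSNorm over the full
+        # hidden size.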
+ return self._forward(x) + + +class MiniMaxText01RotaryEmbedding(CustomOp): + name = "MiniMaxText01RotaryEmbedding" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool, + cache_dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position + self.base = base + self.is_neox_style = is_neox_style + self.cache_dtype = cache_dtype + cache = self._compute_cos_sin_cache().to(cache_dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq( + self, + base: Union[int, float], + ) -> torch.Tensor: + """Compute the inverse frequency.""" + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + self.cos_sin_cache = self.cos_sin_cache.to(positions.device) + query_cast = query.to(self.cache_dtype) + key_cast = key.to(self.cache_dtype) + ops.rotary_embedding(positions, query_cast, key_cast, self.head_size, + self.cos_sin_cache, self.is_neox_style) + query = query_cast.to(query.dtype) + key = key_cast.to(key.dtype) + return query, key + + +class MiniMaxText01MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = None, + prefix: str = "mlp", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + return + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiniMaxText01MoE(nn.Module): + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + layer_idx: int = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "moe", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + self.quant_config = quant_config + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=torch.float32, + quant_config=None, + prefix=f"{prefix}.gate", + ) + self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader + + self.experts = FusedMoE( + 
num_experts=self.num_total_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size * self.tp_size, + params_dtype=self.params_dtype, + reduce_results=True, + renormalize=True, + quant_config=self.quant_config, + tp_size=self.tp_size, + prefix=f"{prefix}.experts", + ) + return + + @staticmethod + def gate_weight_loader(param: nn.Parameter, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight.to(torch.float32)) + return + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32)) + final_hidden_states = self.experts( + hidden_states, router_logits_fp32.to(hidden_states.dtype)) + final_hidden = final_hidden_states.view(num_tokens, hidden_size) + return final_hidden + + +class MiniMaxText01LinearKernel: + + @staticmethod + def jit_linear_forward_prefix(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + block_size: int, + layer_idx: int = None, + **kwargs) -> torch.Tensor: + + slope_rate = slope_rate.to(torch.float32) + should_pad_dim = q.dim() == 3 + if should_pad_dim: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + b, h, n, d = q.shape + e = d + kv_history = kv_caches.reshape(1, h, d, e).contiguous() + output, kv_history = lightning_attention(q, + k, + v, + slope_rate, + block_size=block_size, + kv_history=kv_history) + kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) + assert output.shape[0] == 1, "batch size must be 1" + return rearrange(output.squeeze(0), "h n d -> n (h d)") + + +class MiniMaxText01LinearAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_inner_size: int, + num_heads: int, + head_dim: int, + max_position: int, + block_size: int, + num_hidden_layer: int, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = 0, + linear_layer_idx: int = 0, + prefix: str = "linear_attn", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.BLOCK = block_size + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads + self.hidden_inner_size = hidden_inner_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + assert self.total_num_heads % self.tp_size == 0 + self.tp_heads = self.total_num_heads // self.tp_size + self.qkv_size = self.num_heads * self.head_dim + self.tp_hidden = self.head_dim * self.tp_heads + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size * 3, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.output_gate = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_gate", + ) + self.out_proj = RowParallelLinear( + self.hidden_inner_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.norm = MiniMaxText01RMSNormTP( + self.hidden_inner_size, + eps=1e-5, + ) + + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( + self.num_heads) + if num_hidden_layer <= 1: + self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate * (1 - layer_idx / + (num_hidden_layer - 1) + 1e-5) + self.tp_slope = 
self.slope_rate[self.tp_rank * + self.tp_heads:(self.tp_rank + 1) * + self.tp_heads].contiguous() + + @staticmethod + def weight_direct_load(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + return + + @staticmethod + def _build_slope_tensor(n_attention_heads: int): + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slopes = torch.tensor(get_slopes(n_attention_heads), + dtype=torch.float32).reshape( + n_attention_heads, 1, 1) + return slopes + + def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + hidden = [] + for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + _start = attn_metadata.query_start_loc[_prefill_idx] + _end = attn_metadata.query_start_loc[_prefill_idx + 1] + slot_id = state_indices_tensor[_prefill_idx] + qs = q[_start:_end].transpose(0, 1).contiguous() + ks = k[_start:_end].transpose(0, 1).contiguous() + vs = v[_start:_end].transpose(0, 1).contiguous() + slot_id = state_indices_tensor[_prefill_idx] + slice_layer_cache = kv_cache[slot_id, ...] + + out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( + qs, + ks, + vs, + slice_layer_cache, + self.tp_slope, + self.BLOCK, + layer_idx=self.layer_idx) + hidden.append(out_slice.contiguous()) + if attn_metadata.num_decode_tokens > 0: + hidden.append( + self._decode_infer(q, k, v, kv_cache, state_indices_tensor, + attn_metadata)) + hidden = torch.concat(hidden, dim=0).contiguous() + return hidden + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0 + ):] + hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, + slot_id, 32) + return hidden + + def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, + kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + qkv32 = qkv.to(torch.float32) + qkvact = torch.nn.functional.silu(qkv32) + qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) + q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + kv_cache = kv_caches.minimax_cache + state_indices_tensor = kv_caches.state_indices_tensor + + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if not decode_only: + hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + else: + hidden = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, attn_metadata) + + hidden = self.norm._forward(hidden) + gate, _ = self.output_gate(hidden_states) + hidden = F.sigmoid(gate) * hidden + hidden = hidden.to(hidden_states.dtype) + hidden, _ = self.out_proj(hidden) + return hidden + + +class MiniMaxText01Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + head_dim: 
int, + num_kv_heads: int, + rotary_dim: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + sliding_window: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "mha", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + return + + def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, + **kwargs) -> torch.Tensor: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = attn_metadata.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class MiniMaxText01DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + expert_num: int = 1, + layer_id: int = None, + linear_layer_id: Optional[int] = None, + prefix: str = "decoder", + ) -> None: + self._ilayer = layer_id + self._irank = get_tensor_model_parallel_rank() + super().__init__() + + self.hidden_size = config.hidden_size + self.expert_num = expert_num + + rope_theta = getattr(config, "rope_theta", 10000) + + head_dim = getattr(config, "head_dim", + config.hidden_size // config.num_attention_heads) + if hasattr(config, "max_model_len") and isinstance( + config.max_model_len, int): + max_position_embeddings = min(config.max_position_embeddings, + config.max_model_len) + if config.attention_type == 0: + use_headxdim = True + hidden_inner = (head_dim * config.num_attention_heads + if use_headxdim else config.hidden_size) + self.self_attn = MiniMaxText01LinearAttention( + hidden_size=self.hidden_size, + hidden_inner_size=hidden_inner, + num_heads=config.num_attention_heads, + head_dim=head_dim, + max_position=max_position_embeddings, + block_size=config.block if hasattr(config, "block") else 256, + num_hidden_layer=config.num_hidden_layers, + quant_config=quant_config, + layer_idx=self._ilayer, + linear_layer_idx=linear_layer_id, + prefix=prefix) + elif config.attention_type == 1: + self.self_attn = 
MiniMaxText01Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + head_dim=head_dim, + rotary_dim=config.rotary_dim + if hasattr(config, "rotary_dim") else head_dim, + num_kv_heads=config.num_key_value_heads, + max_position=max_position_embeddings, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + quant_config=quant_config, + layer_idx=self._ilayer, + cache_config=cache_config, + prefix=prefix) + else: + raise ValueError( + f"Unsupported attention type: {self.config.attention_type}") + + if expert_num == 1: + self.mlp = MiniMaxText01MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + layer_idx=self._ilayer, + prefix=prefix) + else: + self.block_sparse_moe = MiniMaxText01MoE( + num_experts=expert_num, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_idx=self._ilayer, + quant_config=quant_config, + prefix=prefix) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + if config.attention_type == 0: + self.layernorm_attention_alpha = getattr( + config, 'layernorm_linear_attention_alpha', 1) + self.layernorm_attention_beta = getattr( + config, 'layernorm_linear_attention_beta', 1) + else: + self.layernorm_attention_alpha = getattr( + config, 'layernorm_full_attention_alpha', 1) + self.layernorm_attention_beta = getattr( + config, 'layernorm_full_attention_beta', 1) + self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1) + self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1) + self.postnorm = getattr(config, 'postnorm', False) + self.shared_moe = False + + shared_intermediate = getattr(config, 'shared_intermediate_size', 0) + if shared_intermediate > 0: + self.shared_moe = True + self.shared_mlp = MiniMaxText01MLP( + hidden_size=self.hidden_size, + intermediate_size=shared_intermediate, + quant_config=quant_config, + layer_idx=self._ilayer, + prefix=prefix) + self.coefficient = ReplicatedLinear( + self.hidden_size, + 1, + bias=False, + quant_config=quant_config, + params_dtype=torch.float32, + ) + self.coefficient.weight.weight_loader = ( + self.shared_moe_coefficient_loader) + self.shared_moe_mode = getattr(config, 'shared_moe_mode', + 'softmax') + return + + def forward(self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: Union[List[Dict], Optional[torch.Tensor]], + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + is_warmup: bool = False, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + layernorm_input = hidden_states + layernorm_output = self.input_layernorm(layernorm_input) + residual = layernorm_output if self.postnorm else layernorm_input + self_attention_output = self.self_attn( + hidden_states=layernorm_output, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + + residual = residual * self.layernorm_attention_alpha + self_attention_output = (self_attention_output * + self.layernorm_attention_beta) + + layernorm_input = residual + self_attention_output + layernorm_output = self.post_attention_layernorm(layernorm_input) + residual = layernorm_output if self.postnorm else layernorm_input + + if self.expert_num == 1: + hidden_states = self.mlp(layernorm_output) + else: + moe_hidden_states = 
self.block_sparse_moe( + copy.deepcopy(layernorm_output)) + if self.shared_moe: + before_moe_dtype = layernorm_output.dtype + moe_hidden_fp32 = moe_hidden_states.to(torch.float32) + output_mlp = self.shared_mlp(layernorm_output).to( + torch.float32) + + coef, _ = self.coefficient(layernorm_output.to(torch.float32)) + + if self.shared_moe_mode == 'softmax': + coef = torch.nn.functional.softmax(coef, dim=-1) + hidden_states = moe_hidden_fp32 * ( + 1 - coef) + output_mlp * coef + elif self.shared_moe_mode == 'sigmoid': + coef = torch.nn.functional.sigmoid(coef) + hidden_states = moe_hidden_fp32 * ( + 1 - coef) + output_mlp * coef + + hidden_states = hidden_states.to(before_moe_dtype) + else: + hidden_states = moe_hidden_states + + residual = residual * self.layernorm_mlp_alpha + hidden_states = hidden_states * self.layernorm_mlp_beta + + hidden_states = residual + hidden_states + + return hidden_states, None + + @staticmethod + def shared_moe_coefficient_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + + param.data.copy_(loaded_weight.to(torch.float32)) + return + + +class MiniMaxText01Model(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + scheduler_config=None, + prefix: str = "", + ) -> None: + super().__init__() + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.decoder_attention_types = getattr( + config, "attn_type_list", False) or getattr( + config, "decoder_attention_types", False) + if not self.decoder_attention_types: + self.decoder_attention_types = [1] * config.num_hidden_layers + self.num_layers = config.num_hidden_layers + + self._layer_barrier = False + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + def layer_fn(prefix): + layer_idx = int(prefix.split('.')[-1]) + layer_config = config + layer_config.attention_type = self.decoder_attention_types[ + layer_idx] + layer_config.layer_idx = layer_idx + + decoder_kwargs = { + "quant_config": quant_config, + "layer_id": layer_idx, + "cache_config": cache_config + } + + if layer_config.attention_type == 0: + decoder_kwargs["linear_layer_id"] = sum( + 1 for i in range(layer_idx) + if self.decoder_attention_types[i] == 0) + else: + decoder_kwargs["linear_layer_id"] = None + + if hasattr(config, "num_local_experts") and isinstance( + config.num_local_experts, list): + decoder_kwargs["expert_num"] = config.num_local_experts[ + layer_idx] + elif hasattr(config, "num_local_experts") and isinstance( + config.num_local_experts, int): + decoder_kwargs["expert_num"] = config.num_local_experts + else: + decoder_kwargs["expert_num"] = 1 + + return MiniMaxText01DecoderLayer(layer_config, + **decoder_kwargs, + prefix=prefix) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers") + + linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) + if self.decoder_attention_types[i] == 0) + max_slots_number = scheduler_config.max_num_seqs + self.cache_shape = (linear_layer_nums, max_slots_number, + config.num_attention_heads // + get_tensor_model_parallel_world_size(), + config.head_dim, config.head_dim) + _dummy = torch.zeros(1) + self._dtype = _dummy.dtype + del _dummy + + self.minimax_cache 
= MinimaxCacheManager(dtype=self._dtype, + cache_shape=self.cache_shape) + + rope_theta = getattr(config, "rope_theta", 10000) + head_dim = getattr(config, "head_dim", + config.hidden_size // config.num_attention_heads) + if hasattr(config, "max_model_len") and isinstance( + config.max_model_len, int): + max_position_embeddings = min(config.max_position_embeddings, + config.max_model_len) + self.rotary_emb = MiniMaxText01RotaryEmbedding( + head_dim, + rotary_dim=config.rotary_dim + if hasattr(config, "rotary_dim") else head_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + is_neox_style=True, + cache_dtype=torch.float32, + ) + + norm_kwargs = {} + if hasattr(config, "rms_norm_eps"): + norm_kwargs["eps"] = config.rms_norm_eps + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, **norm_kwargs) + else: + self.norm = PPMissingLayer() + self.embed_scale = 1.0 + return + + def _clear_prefill_cache(self, attn_metadata, + minimax_cache_tensors: torch.Tensor, **kwargs): + seq_to_slot_maps = {} + seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), []) + for _, seq_to_slot_map in ( + self.minimax_cache.cache_indices_mapping.items()): + seq_to_slot_maps.update(seq_to_slot_map) + + slots_to_clear = [] + for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)): + seq_id = seq_id_map[_prefill_id] + if attn_metadata.context_lens_tensor[ + _prefill_id] == 0 and seq_id in seq_to_slot_maps: + slots_to_clear.append(seq_to_slot_maps[seq_id]) + + if slots_to_clear: + slots_tensor = torch.tensor(slots_to_clear, + device=minimax_cache_tensors.device, + dtype=torch.long) + minimax_cache_tensors[:, slots_tensor, ...] = 0 + + def forward(self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + intermediate_tensors=None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return None + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] + ( + minimax_cache_tensors, + state_indices_tensor, + ) = self.minimax_cache.current_run_tensors(**kwargs) + if getattr(attn_metadata, "num_prefills", 0) > 0: + self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, + **kwargs) + + minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, + state_indices_tensor) + if get_pp_group().is_first_rank: + if inputs_embeds is None: + hidden_states = self.embed_scale * self.embed_tokens(input_ids) + else: + hidden_states = inputs_embeds + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + kv_cache_index = 0 + minimax_cache_index = 0 + attn_metadata.rotary_emb = self.rotary_emb + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + _caches = None + if isinstance(layer.self_attn, MiniMaxText01Attention): + _caches = kv_caches[kv_cache_index] + kv_cache_index += 1 + if isinstance(layer.self_attn, MiniMaxText01LinearAttention): + current_state_layer = minimax_cache_index + _caches = minimax_cache_params.at_layer_idx( + current_state_layer) + minimax_cache_index += 1 + hidden_states, residual = layer( + hidden_states=hidden_states, + positions=positions, + kv_caches=_caches, + attn_metadata=attn_metadata, + 
residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + if not hasattr(config, "sliding_window"): + config.sliding_window = None + + self.CONCAT_FFN = True + + self.unpadded_vocab_size = self.config.vocab_size + if hasattr(vllm_config.model_config, "max_model_len"): + self.config.max_model_len = vllm_config.model_config.max_model_len + self.model = MiniMaxText01Model( + self.config, + quant_config, + cache_config=vllm_config.cache_config, + scheduler_config=vllm_config.scheduler_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size) + + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + flash_layer_count = sum(1 for attn_type in self.config.attn_type_list + if attn_type == 1) + self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] + return + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.model.minimax_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( + batch_size) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, self.kv_cache, + intermediate_tensors, inputs_embeds, + **kwargs) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + + next_tokens = self.sampler(logits, sampling_metadata) + + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> None: + params_dict = dict(self.named_parameters()) + + def which_layer(name: str) -> int: + if "layers" in name: + after_layer = name.split("layers")[-1] + return int(after_layer.split(".")[1]) + return None + + def is_linear_attn_layer(layer_idx: int) -> bool: + if layer_idx is None or not hasattr(self.config, 
"attn_type_list"): + return False + return self.config.attn_type_list[layer_idx] == 0 + + def is_moe_weight(name: str) -> bool: + return "block_sparse_moe" in name and not name.endswith(".bias") + + def get_expert_id(param_name): + pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.' + match = re.search(pattern, param_name) + if match: + return match.group(1) + return None + + def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if isinstance(self.config.num_local_experts, list): + expert_params_mapping = [ + ("w13_weight" + if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(max(self.config.num_local_experts)) + for weight_name in ["w1", "w2", "w3"] + ] + else: + expert_params_mapping = [ + ("w13_scale" if weight_name in ["w1", "w3"] else + "w2_scale", f"{expert_id}.{weight_name}.weight_scale", + expert_id, weight_name) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [("w13_weight" if weight_name in ["w1", "w3"] else + "w2_weight", f"{expert_id}.{weight_name}.weight", + expert_id, weight_name) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"]] + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + name_expert_id = get_expert_id(name) + if name_expert_id is not None and int(name_expert_id) != int( + expert_id): + continue + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id, + shard_id=shard_id) + break + else: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def is_shared_mlp_weight(name: str) -> bool: + return "shared_mlp" in name and not name.endswith(".bias") + + def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if not self.CONCAT_FFN: + if "gate_proj" in name: + name = name.replace("gate_proj", "w1", 1) + elif "up_proj" in name: + name = name.replace("up_proj", "w3", 1) + elif "down_proj" in name: + name = name.replace("down_proj", "w2", 1) + else: + if "gate_proj" in name: + name = name.replace("gate_proj", "gate_up_proj", 1) + loaded_shard_id = 0 + elif "up_proj" in name: + name = name.replace("up_proj", "gate_up_proj", 1) + loaded_shard_id = 1 + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + if not self.CONCAT_FFN: + weight_loader(param, loaded_weight) + else: + if "gate_up_proj" in name: + weight_loader(param, loaded_weight, loaded_shard_id) + elif "down_proj" in name: + weight_loader(param, loaded_weight) + else: + raise AssertionError( + "MLP weight not in [gate_up_proj, down_proj]") + return + + def is_mha_weight(name: str) -> bool: + return "self_attn" in name and not name.endswith(".bias") + + def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if is_pp_missing_parameter(name, self): 
+ return + param = params_dict[name] + + weight_loader = getattr( + param, "weight_loader", + MiniMaxText01LinearAttention.weight_direct_load) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + + flash_mha_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + for (param_name, weight_name, + shard_id) in flash_mha_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def is_layer_norm_weight(name: str) -> bool: + return "norm" in name and not name.endswith( + ".bias") and name in params_dict + + def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def load_basic_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + for name, loaded_weight in weights: + weight_at_layer = which_layer(name) + if weight_at_layer and weight_at_layer >= len( + self.config.attn_type_list): + continue + + if is_layer_norm_weight(name): + load_layer_norm_weight(name, loaded_weight, self) + continue + if is_mha_weight(name): + if is_linear_attn_layer(weight_at_layer): + load_linear_attn_weight(name, loaded_weight, self) + else: + load_flash_attn_weight(name, loaded_weight, self) + continue + if is_moe_weight(name): + load_sparse_moe_weight(name, loaded_weight, self) + continue + if is_shared_mlp_weight(name): + load_shared_mlp_weight(name, loaded_weight, self) + continue + + if "rotary_emb.inv_freq" in name: + continue + + load_basic_weight(name, loaded_weight, self) + return diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py new file mode 100644 index 0000000..872769d --- /dev/null +++ b/vllm/model_executor/models/mistral3.py @@ -0,0 +1,659 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, + TypeVar, Union) + +import torch +import torch.nn as nn +from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig, + PretrainedConfig) +from transformers.models.pixtral import PixtralProcessor + +from vllm.config import VllmConfig 
+from vllm.inputs import InputProcessingContext +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, ProcessingCache, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import (get_vision_encoder_info, scatter_patch_features, + select_patch_features) + + +class Mistral3ImagePixelInputs(TypedDict): + type: Literal["pixel_values_pixtral"] + pixel_values: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that `height` or `width` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. 
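+    (``False`` entries correspond to the image-break / image-end tokens
+    appended after each row of patches.)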
+ + Shape: `(batch_size, num_images, num_embeds)` + """ + + +class Mistral3PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__(self, vision_hidden_size: int, spatial_merge_size: int, + patch_size: int): + super().__init__() + + self.vision_hidden_size = vision_hidden_size + self.spatial_merge_size = spatial_merge_size + self.patch_size = patch_size + self.merging_layer = nn.Linear(vision_hidden_size * + self.spatial_merge_size**2, + vision_hidden_size, + bias=False) + + def forward(self, image_features: torch.Tensor, + image_sizes: torch.Tensor) -> torch.Tensor: + image_sizes = [(image_size[0] // self.patch_size, + image_size[1] // self.patch_size) + for image_size in image_sizes] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate( + image_features.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, + 1).unsqueeze(0) + grid = torch.nn.functional.unfold( + image_grid, + kernel_size=self.spatial_merge_size, + stride=self.spatial_merge_size) + grid = grid.view(d * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + return image_features + + +class Mistral3MultiModalProjector(nn.Module): + + def __init__(self, + vision_hidden_size: int, + text_hidden_size: int, + spatial_merge_size: int, + patch_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.norm = RMSNorm(vision_hidden_size, eps=1e-5) + self.patch_merger = Mistral3PatchMerger( + vision_hidden_size=vision_hidden_size, + spatial_merge_size=spatial_merge_size, + patch_size=patch_size) + + self.linear_1 = ColumnParallelLinear(vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1") + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = RowParallelLinear(text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2") + + def forward(self, image_features: torch.Tensor, + image_sizes: torch.Tensor) -> torch.Tensor: + image_features = self.norm(image_features) + image_features = self.patch_merger(image_features, image_sizes) + hidden_states, _ = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states + + +class LlavaLikeConfig(Protocol): + vision_config: Final[PretrainedConfig] + image_token_index: Final[int] + vision_feature_select_strategy: Final[str] + vision_feature_layer: Final[Union[int, list[int]]] + + +class LlavaLikeProcessor(Protocol): + image_token: Final[str] + + +class BaseLlavaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> LlavaLikeConfig: + return self.ctx.get_hf_config(Mistral3Config) + + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) + + @abstractmethod + def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor: + raise NotImplementedError + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: 
Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + vision_encoder_info = self.get_vision_encoder_info() + return vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +_I = TypeVar("_I", bound=BaseLlavaProcessingInfo) + + +class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class Mistral3ProcessingInfo(BaseLlavaProcessingInfo): + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(PixtralProcessor, **kwargs) + + +class Mistral3MultiModalProcessor( + BaseMultiModalProcessor[Mistral3ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + pixel_values = processed_outputs.get("pixel_values") + if pixel_values is not None: + + # Avoid padding since we need the output for each image to be + # independent of other images for the cache to work correctly + image_sizes = processed_outputs["image_sizes"] + assert len(pixel_values) == len(image_sizes) + + processed_outputs["pixel_values"] = [ + p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes) + ] + + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + assert isinstance(vision_config, PixtralVisionConfig) + encoder_info = PixtralHFEncoderInfo(vision_config) + + tile_sizes = [ + encoder_info.get_patch_grid_size( + image_width=pixel_value.shape[-1], + image_height=pixel_value.shape[-2], + ) for pixel_value in processed_outputs["pixel_values"] + ] + embed_is_patch = [ + torch.tensor(([True] * ncols + [False]) * nrows) + for ncols, nrows in tile_sizes + ] + processed_outputs["embed_is_patch"] = embed_is_patch + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = 
self.info.get_hf_processor(**hf_processor_mm_kwargs) + hf_config = self.info.get_hf_config() + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + image_break_id = vocab[processor.image_break_token] + image_token_id = hf_config.image_token_index + image_end_id = vocab[processor.image_end_token] + + vision_config = hf_config.vision_config + assert isinstance(vision_config, PixtralVisionConfig) + encoder_info = PixtralHFEncoderInfo(vision_config) + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = encoder_info.get_patch_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + + tokens = ([image_token_id] * ncols + [image_break_id]) * nrows + tokens[-1] = image_end_id + + return tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +def _build_mistral3_info( + ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + hf_config = ctx.get_hf_config(Mistral3Config) + assert isinstance(hf_config.vision_config, PixtralVisionConfig) + return Mistral3ProcessingInfo(ctx) + + +def _build_mistral3_processor( + info: _I, + dummy_inputs: BaseDummyInputsBuilder[_I], + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True, +) -> BaseMultiModalProcessor: + assert isinstance(info, Mistral3ProcessingInfo) + return Mistral3MultiModalProcessor( + info, + dummy_inputs, # type: ignore + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) + + +def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: + """Determine the number of hidden layers to initialize up to in the + visual encoder. + + Args: + hf_config: Model config with vision feature layer(s). + """ + feature_layers = hf_config.vision_feature_layer + num_hidden_layers = hf_config.vision_config.num_hidden_layers + # If we have one feature layer, initialize up to that layer + if isinstance(feature_layers, int): + return _get_layer_index(feature_layers, num_hidden_layers) + # If we have multiple feature layers, initialize up to the deepest one + elif isinstance(feature_layers, (list, tuple)): + return max( + _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + +def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: + """Given a signed vision feature layer, get the number of hidden layers + needed to leverage it. + + Args: + feature_layer_index: Index of a required layer in the visual encoder. + num_hidden_layers: The total number of hidden layers in the visual + encoder. 
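+
+    For example, with ``num_hidden_layers=24`` a feature layer index of
+    ``-2`` resolves to ``24 + (-2) + 1 = 23``, while a non-negative index is
+    returned unchanged.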
+ """ + if feature_layer_index < 0: + return num_hidden_layers + feature_layer_index + 1 + return feature_layer_index + + +def init_vision_tower_for_llava( + hf_config: LlavaLikeConfig, + quant_config: Optional[QuantizationConfig], + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", +) -> PixtralHFVisionModel: + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the deepest required feature layer + num_hidden_layers = _get_num_hidden_layers(hf_config) + + assert isinstance(vision_config, PixtralVisionConfig) + + return PixtralHFVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + require_post_norm=require_post_norm, + prefix=prefix, + ) + + +# TODO(mgoin): Support V1, there are issues with image batching/chunking +# that need to be resolved first. +@MULTIMODAL_REGISTRY.register_processor( + _build_mistral3_processor, + info=_build_mistral3_info, + dummy_inputs=Mistral3DummyInputsBuilder) +class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + # NOTE: These are special cases for Pixtral-12B in the HF-format + # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa + if (config.text_config.architectures is None + and config.text_config.model_type == "mistral"): + config.text_config.architectures = ["MistralForCausalLM"] + if (config.projector_hidden_act is None + and config.vision_config.hidden_act == "gelu"): + config.projector_hidden_act = "gelu" + + # TODO: Optionally initializes this for supporting embeddings. + self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.multi_modal_projector = Mistral3MultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + spatial_merge_size=config.spatial_merge_size, + patch_size=config.vision_config.patch_size, + multimodal_projector_bias=config.multimodal_projector_bias, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. 
" + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + assert pixel_values is not None + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + assert self.config.vision_config.model_type == "pixtral" + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + return Mistral3ImagePixelInputs( + type="pixel_values_pixtral", + pixel_values=flatten_bn(pixel_values), + embed_is_patch=flatten_bn(embed_is_patch), + ) + + def _process_image_input( + self, + image_input: Mistral3ImagePixelInputs, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + image_sizes = [(img.shape[-2], img.shape[-1]) + for img in image_input["pixel_values"]] + + image_features = self.vision_tower(image_input["pixel_values"]) + + if isinstance(image_features, torch.Tensor): + return self.multi_modal_projector(image_features, image_sizes) + + feature_sizes = [ + image_feature.shape[0] // self.config.spatial_merge_size**2 + for image_feature in image_features + ] + + image_embeds = self.multi_modal_projector(torch.cat(image_features), + image_sizes) + if len(feature_sizes) > 1: + image_embeds = torch.split(image_embeds, feature_sizes) + else: + image_embeds = (image_embeds, ) + return image_embeds + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + vision_embeddings = self._process_image_input(image_input) + + return scatter_patch_features( + vision_embeddings, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.config.image_token_index, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Mistral3. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + + Concretely, consider a text prompt: + `"USER: \\nWhat's the content of the image?\\nASSISTANT:"`. + + Tokenizer outputs: + `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879, + 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`. 
+ + To reserve space in KV cache, we have to insert placeholder tokens + before they are inputted to the model, so the input processor prepends + additional image tokens (denoted as `32000`), resulting in: + `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618, + 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, + 29901]`. + + We insert 575 tokens so that including the original image token in the + input, there are a total of 576 (24 * 24) image tokens, which + corresponds to the number of image tokens inputted to the language + model, i.e. the number of image tokens outputted by the visual encoder. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: The pixels in each input image. + + See also: + :class:`Mistral3ImagePixelInputs` + """ + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py new file mode 100644 index 0000000..0d5cc63 --- /dev/null +++ b/vllm/model_executor/models/mixtral.py @@ -0,0 +1,492 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Mixtral model.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import MixtralConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class MixtralMoE(nn.Module): + """A tensor-parallel MoE implementation for Mixtral that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + dp_size=dp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class MixtralAttention(nn.Module): + + def __init__( + self, + config: MixtralConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MixtralConfig has an optional head_dim argument + self.head_dim = getattr(config, "head_dim") + if self.head_dim is None: + self.head_dim = (self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.block_sparse_moe = MixtralMoE( + 
num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class MixtralModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = MixtralModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py new file mode 100644 index 0000000..5be91f4 --- /dev/null +++ b/vllm/model_executor/models/mixtral_quant.py @@ -0,0 +1,451 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
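+# This variant keeps each expert as a plain (optionally quantized) MLP and
+# splits the experts across tensor-parallel ranks: every rank runs only its
+# local experts and the partial outputs are combined with an all-reduce.
+# Routing sketch (illustrative numbers only): with 4 experts and top_k=2,
+# router logits [2.0, 1.0, 0.5, -1.0] softmax to roughly
+# [0.61, 0.22, 0.14, 0.03]; the top-2 weights renormalize to roughly
+# [0.73, 0.27], so the token's output is
+# 0.73 * expert_0(x) + 0.27 * expert_1(x). MixtralMoE.forward below follows
+# this softmax -> top-k -> renormalize -> weighted-sum pattern.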
+"""Inference-only Mixtral model.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import MixtralConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class MixtralMLP(nn.Module): + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + self.w2 = ReplicatedLinear(self.ffn_dim, + self.hidden_dim, + bias=False, + quant_config=quant_config) + self.w3 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralMoE(nn.Module): + + def __init__( + self, + config: MixtralConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}.") + # Split experts equally between ranks + self.expert_indicies = np.array_split(range( + self.num_total_experts), self.tp_size)[self.rank].tolist() + if not self.expert_indicies: + raise ValueError( + f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + self.gate = ReplicatedLinear(config.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None) + + def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + num_tokens, hidden_dim) + + +class MixtralAttention(nn.Module): + + def __init__( + self, + config: MixtralConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
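+            # For example, 2 KV heads with tp_size=8 gives
+            # max(1, 2 // 8) == 1 local KV head, i.e. each KV head is
+            # replicated on 4 ranks (the assert below requires 8 % 2 == 0).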
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MixtralConfig has an optional head_dim argument + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.block_sparse_moe = MixtralMoE(config=config, + quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MixtralDecoderLayer( + config, 
cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module, SupportsPP): + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = MixtralModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + 
continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py new file mode 100644 index 0000000..1dc8adb --- /dev/null +++ b/vllm/model_executor/models/mllama.py @@ -0,0 +1,1618 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union + +import numpy as np +import torch +import torch.nn.functional as F +import transformers.models.mllama.configuration_mllama as config_mllama +from PIL.Image import Image +from torch import nn +from transformers import BatchFeature, MllamaConfig +from transformers.modeling_outputs import (BaseModelOutput, + CausalLMOutputWithPast) +from transformers.models.mllama.image_processing_mllama import ( + get_optimal_tiled_canvas) +from transformers.models.mllama.processing_mllama import ( + MllamaProcessor, get_cross_attention_token_mask) + +import vllm.distributed.parallel_state as ps +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.selector import _Backend +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group, get_tp_group +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVCrossParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import 
SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalEncDecInputs, + MultiModalFieldConfig, MultiModalKwargs) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataDict, MultiModalDataItems) +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs + +from .clip import CLIPMLP +from .interfaces import SupportsMultiModal, SupportsV0Only +from .llama import LlamaDecoderLayer, LlamaMLP +from .utils import maybe_prefix +from vllm import _custom_ops as ops + +logger = init_logger(__name__) + + +class MllamaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: """ + """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" + aspect_ratio_ids: torch.Tensor + """Shape: `(batch_size, max_num_image)`""" + aspect_ratio_mask: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_tiles)`""" + + +# TODO: support LlamaImageEmbeddingInputs + + +def calc_token_per_chunk(image_size: int) -> int: + assert image_size % 14 == 0, "chunk size should be multiple of 14" + token_per_chunk = (image_size // 14)**2 + 1 + return token_per_chunk + + +class MllamaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> MllamaConfig: + return self.ctx.get_hf_config(MllamaConfig) + + def get_hf_processor(self, **kwargs: object) -> MllamaProcessor: + return self.ctx.get_hf_processor(MllamaProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_token_per_chunk_from_config(self) -> int: + image_size = self.get_hf_config().vision_config.image_size + return calc_token_per_chunk(image_size) + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + vision_config = self.get_hf_config().vision_config + token_per_chunk = self.get_token_per_chunk_from_config() + mm_max_tokens = vision_config.max_num_tiles * token_per_chunk + return {"image": mm_max_tokens} + + def get_num_tiles_per_image(self, image_height: int, + image_width: int) -> int: + vision_config = self.get_hf_config().vision_config + max_num_tiles = vision_config.max_num_tiles + image_size = vision_config.image_size + tiled_height, tiled_width = get_optimal_tiled_canvas( + image_height, + image_width, + max_num_tiles, + tile_size=image_size, + ) + num_tiles_height = tiled_height // image_size + num_tiles_width = tiled_width // image_size + return num_tiles_height * num_tiles_width + + def get_image_size_with_most_features(self) -> ImageSize: + vision_config = self.get_hf_config().vision_config + image_size = vision_config.image_size + max_num_tiles = vision_config.max_num_tiles + # Result in the max possible feature size (h:w = 16:1) + return ImageSize(height=max_num_tiles * image_size, width=image_size) + + +class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + hf_processor = self.info.get_hf_processor() + image_token: str = 
hf_processor.image_token + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] + ): + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalEncDecInputs: + mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, + return_mm_hashes) + + image_token_id = self.info.get_hf_config().image_token_index + # Check that the number of image tokens in the decoder prompt matches + # the number of images provided in mm_data + num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id) + image_data = mm_data.get("image", []) + num_images = 1 if isinstance(image_data, Image) else len(image_data) + if num_image_tokens != num_images: + raise ValueError( + f"The number of image tokens ({num_image_tokens}) must be" + f" the same as the number of images ({num_images})") + + # Given prompt: P0 P1 P3 P4 D5 D6...., (P-prefill, D-decode) # noqa: E501 + # P0 & P1 do cross attention with placeholder of + # P3 P4 D5 D6 do cross attention with placeholder of and + # Example input to encoder and decoder: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128256, 128256, ..., 128256], + # 'prompt': '<|image|><|image|>...<|image|>', + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # } + + if mm_data: + # Since only the last group of consecutive images + # are attended by the decoded tokens, we only need to + # get the number of tokens for those images. + token_per_chunk = self.info.get_token_per_chunk_from_config() + num_decode_images = self._get_num_image_in_last_group( + mm_inputs["prompt_token_ids"]) + num_encode_images = num_images - num_decode_images + + # Set encoder prompt length based on the number of tiles. + # This tells the block manager to allocate correct number + # of slots for encoder tokens. 
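+            # Worked example (assuming the reference vision config of
+            # 560-pixel tiles and 14-pixel patches): a single 1120x560 image
+            # tiles into 2 chunks, each contributing (560 // 14) ** 2 + 1
+            # = 1601 tokens, so the encoder prompt becomes 2 * 1601 = 3202
+            # image tokens.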
+ num_tiles = mm_inputs["mm_kwargs"]["num_tiles"] + decode_tiles = num_tiles[num_encode_images:num_images].sum().item() + num_tokens = decode_tiles * token_per_chunk + mm_inputs["encoder_prompt_token_ids"] = [image_token_id + ] * num_tokens + mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens + + return mm_inputs + + def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int: + num_images = 0 + for token_id in prompt_token_ids[::-1]: + if token_id == self.info.get_hf_config().image_token_index: + num_images += 1 + elif num_images > 0: + break + return num_images + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + if mm_data: + num_tiles = [ + self.info.get_num_tiles_per_image(img.height, img.width) + for img in mm_data["images"] + ] + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs) + processed_outputs["num_tiles"] = torch.tensor(num_tiles) + for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"): + processed_outputs[k] = processed_outputs[k].squeeze(0) + + processed_token_ids = processed_outputs.pop("input_ids") + start_idx, end_idx = 0, processed_token_ids.size(1) + processed_prompt_text = tokenizer.decode(processed_token_ids[0]) + + hf_processor = self.info.get_hf_processor() + bos_token = hf_processor.bos_token + # Remove the bos_token from the start of prompt, + # because we all know there would be image_token. + if processed_prompt_text.startswith(bos_token): + start_idx += 1 + # Remove the bos_token from the end of prompt, + # because text is empty in this case. + if processed_prompt_text.endswith(bos_token): + end_idx -= 1 + processed_outputs[ + "input_ids"] = processed_token_ids[:, start_idx:end_idx] + else: + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + aspect_ratio_ids=MultiModalFieldConfig.batched("image"), + aspect_ratio_mask=MultiModalFieldConfig.batched("image"), + num_tiles=MultiModalFieldConfig.batched("image"), + ) + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + data = mm_data.get("image", []) + num_images = 1 if isinstance(data, Image) else len(data) + image_token_id = self.info.get_hf_config().image_token_index + return [image_token_id] * num_images + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + token_per_chunk = self.info.get_token_per_chunk_from_config() + image_token_id = self.info.get_hf_config().image_token_index + + def get_replacement_mllama(item_idx): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + num_tile = self.info.get_num_tiles_per_image( + image_height=image_size.height, + image_width=image_size.width, + ) + num_tokens = num_tile * token_per_chunk + return [image_token_id] * num_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_mllama, + ) + ] + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + 
num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x, _ = self._linear(x) + return x + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + 
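+            # Each aspect-ratio id maps to one flat vector of size
+            # max_num_tiles * num_patches * hidden_size, which forward()
+            # reshapes into per-tile, per-patch position embeddings.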
self.max_num_tiles * self.num_patches * self.hidden_size) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = self.gate.tanh( + ) * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +# TODO: support other attention backends for attention in vision model +class MllamaVisionSdpaAttention(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + tensor_parallel_size = get_tp_group().world_size + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // tensor_parallel_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + self.scale = 1 / math.sqrt(self.head_dim) + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = torch.empty_like(q) + for i in range(q.shape[0]): + attn = q[i] @ (k[i] * self.scale).permute(0,2,1) + if attention_mask is not None: + attn = attn + attention_mask[i] + attn = torch.softmax(attn, dim=-1) + + output = attn @ v[i] + attn_output[i] = output + # attn_output = F.scaled_dot_product_attention(q, + # k, + # v, + # attn_mask=attention_mask, + # dropout_p=0.0) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(attn_output.shape[0], + attn_output.shape[1], -1) + output, _ = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + + def __init__( + self, + config: config_mllama.MllamaVisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + is_gated: bool = False, + ) -> None: + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention( + config, quant_config=quant_config, prefix=f"{prefix}.self_attn") + self.mlp = 
CLIPMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + + def __init__( + self, + config: config_mllama.MllamaVisionConfig, + quant_config: Optional[QuantizationConfig], + num_layers: int = 32, + is_gated: bool = False, + output_hidden_states=None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, + quant_config=quant_config, + is_gated=is_gated, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_layers) + ]) + self.output_hidden_states = output_hidden_states or [] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + encoder_states = () + + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + + if len(self.layers) - 1 in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + + def __init__( + self, + config: config_mllama.MllamaVisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> None: + super().__init__() + + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = ColumnParallelConv2dPatch( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config) + + self.pre_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.post_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = 
MllamaVisionEncoder( + config, + quant_config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices, + prefix=f"{prefix}.transformer", + ) + self.global_transformer = MllamaVisionEncoder( + config, + quant_config, + config.num_global_layers, + is_gated=True, + prefix=f"{prefix}.global_transformer", + ) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward(self, pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, \ + height, width = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) + + # patch embedding + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype)) + hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state, intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, + dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + hidden_state = self.global_transformer( + 
hidden_state, attention_mask=attention_mask)[0] + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + return hidden_state + + +class MllamaTextRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Optional[config_mllama.MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.pipeline_parallel_rank = get_pp_group().rank_in_group + self.tensor_parallel_size = get_tp_group().world_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + self.num_local_heads = self.num_heads // self.tensor_parallel_size + self.num_local_key_value_heads = \ + self.num_key_value_heads // self.tensor_parallel_size + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim + + self.qkv_proj = QKVCrossParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = Attention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER, + ) + + def forward( + self, + hidden_states: 
torch.Tensor, + attention_mask: Optional[torch.Tensor], + kv_range_for_decode: Optional[List[Tuple[int, int]]], + cross_attention_states: Optional[torch.Tensor], + ) -> torch.Tensor: + q, k, v = self.qkv_proj(hidden_states, cross_attention_states) + if cross_attention_states is not None: + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + + q = q.view(-1, self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + if attention_mask is not None: + output = self._attention_with_mask(q, k, v, attention_mask, + kv_range_for_decode) + else: + output = self.attn( + q.view(-1, self.num_local_heads * self.head_dim), k, v) + out, _ = self.o_proj(output) + return out + + def _attention_with_mask( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: torch.Tensor, + kv_range_for_decode: List[Tuple[int, int]], + ) -> torch.Tensor: + kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + # Skip writing kv-cache for the initial profiling run. + # TODO (NickLucche) replace with custom attn bias and use standard attn + if len(kv_cache.shape) > 1: + i = torch.ones(1, dtype=torch.float32) + if self.attn.backend in (_Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1): + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + ops.reshape_and_cache_flash( + cached_k, + cached_v, + kv_cache[0], + kv_cache[1], + attn_metadata. + cross_slot_mapping, # type: ignore[union-attr] + "auto", + i, + i, + ) + elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH, + _Backend.TORCH_SDPA): + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_local_key_value_heads, self.head_dim) + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + PagedAttention.write_to_paged_cache( + cached_k, cached_v, key_cache, value_cache, + attn_metadata.cross_slot_mapping, "auto", i, i) + else: + raise ValueError( + f"Unsupported Attention backend {self.attn.backend} " + "enum found. Expected the Attention backend to be " + "FLASH_ATTN, FLASH_ATTN_VLLM_V1, " + "XFORMERS or TORCH_SDPA.") + + # We have to call torch.sdpa for prefill when using a + # custom cross-attention mask. Because the mask is not a + # standard causal mask, neither a block diagonal mask which + # can be optimized by xformers.BlockDiagonalMask. + # The mask is specially calculated for supporting multi + # images and interleaved images. 
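+        # Descriptive note (added): the reshapes below group the query heads by
+        # their KV head and broadcast k/v across each group, so that
+        # F.scaled_dot_product_attention performs grouped-query attention with
+        # the custom, non-causal cross-attention mask.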
+ q_len = q.shape[0] + kv_len = k.shape[0] + q = q.transpose(0, 1).view(self.num_local_key_value_heads, + self.num_key_value_groups, q_len, + self.head_dim).contiguous() + k = k.transpose(0, + 1)[:, + None, :, :].expand(self.num_local_key_value_heads, + self.num_key_value_groups, + kv_len, + self.head_dim).contiguous() + v = v.transpose(0, + 1)[:, + None, :, :].expand(self.num_local_key_value_heads, + self.num_key_value_groups, + kv_len, + self.head_dim).contiguous() + attention_mask = attention_mask.view(1, 1, q_len, kv_len) + output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attention_mask, + is_causal=False) + output = output.permute(2, 0, 1, 3).reshape( + q_len, self.num_local_heads * self.head_dim) + return output + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" + + def __init__( + self, + config: config_mllama.MllamaTextConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_idx=layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.cross_attn", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + kv_range_for_decode: Optional[List[Tuple[int, int]]], + full_text_row_masked_out_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + kv_range_for_decode=kv_range_for_decode, + cross_attention_states=cross_attention_states, + ) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh( + ) * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh( + ) * hidden_states + return hidden_states + + +class MllamaTextModel(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "model" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config.text_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, + config.hidden_size) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_idx in range(config.num_hidden_layers): + if layer_idx in self.cross_attention_layers: + layers.append( + MllamaCrossAttentionDecoderLayer( + config, + layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + 
)) + else: + # TODO: force LlamaDecoderLayer to config.attention_bias=False + layers.append( + LlamaDecoderLayer( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + )) + + self.layers = nn.ModuleList(layers) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + kv_range_for_decode: Optional[List[Tuple[int, int]]], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + skip_cross_attention: bool, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + if idx in self.cross_attention_layers: + if not skip_cross_attention: + hidden_states = decoder_layer( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + kv_range_for_decode=kv_range_for_decode, + full_text_row_masked_out_mask= + full_text_row_masked_out_mask, + ) + else: + hidden_states, residual = decoder_layer( + positions=positions, + hidden_states=hidden_states, + residual=None, + ) + hidden_states = hidden_states + residual + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MllamaForCausalLM(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "language_model" + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" + ] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config.text_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + self.model = MllamaTextModel(vllm_config=vllm_config, + prefix=f"{prefix}.model") + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + kv_range_for_decode: Optional[List[Tuple[int, int]]], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + skip_cross_attention: bool, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + kv_range_for_decode=kv_range_for_decode, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + skip_cross_attention=skip_cross_attention, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor, + info=MllamaProcessingInfo, + dummy_inputs=MllamaDummyInputsBuilder) +class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: MllamaConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.quant_config = quant_config + self.vocab_size = 
config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = \ + config.pad_token_id if config.pad_token_id is not None else -1 + self.image_size = config.vision_config.image_size + self.image_token_id = config.image_token_index + + self.vision_model = MllamaVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_model")) + self.language_model = MllamaForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.multi_modal_projector = ColumnParallelLinear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + quant_config=quant_config, + gather_output=True, + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) + self.logits_processor = LogitsProcessor(config.output_hidden_states, + config.text_config.vocab_size) + self.sampler = get_sampler() + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.language_model.lm_head, + hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def unpack_data(self, + image_data: Union[List[torch.Tensor], torch.Tensor], + padding_value=0) -> torch.Tensor: + if isinstance(image_data, torch.Tensor): + # torch.Tensor + return image_data + else: + assert isinstance( + image_data[0], + torch.Tensor), "Image data is not properly batched." + # List[torch.Tensor] + bsz = len(image_data) + max_length = max(t.size(0) for t in image_data) + trailing_dims = image_data[0].shape[1:] + for data in image_data: + cur_trailing_dims = data.shape[1:] + assert cur_trailing_dims == trailing_dims + output_tensor = torch.full((bsz, max_length, *trailing_dims), + padding_value, + dtype=image_data[0].dtype, + device=image_data[0].device) + for i, t in enumerate(image_data): + output_tensor[i, :t.size(0)] = t + return output_tensor + + def _parse_and_validate_image_input(self, **kwargs: object): + # tensor with the same shape will be batched together by + # MultiModalKwargs.batch, so pixel_values here can be: + # - List[torch.Tensor]: + # with shape (num_image, num_tiles, 3, image_res, image_res) + # - torch.Tensor: + # with shape (bs, num_image, num_tiles, 3, image_res, image_res) + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_ids", None) + aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_mask", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + assert aspect_ratio_ids is not None + assert aspect_ratio_mask is not None + + return MllamaImagePixelInputs( + type="pixel_values", + 
data=self.unpack_data(pixel_values), + aspect_ratio_ids=self.unpack_data(aspect_ratio_ids), + aspect_ratio_mask=self.unpack_data(aspect_ratio_mask)) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def flat_encoder_result(self, cross_attention_states: torch.Tensor, + attn_metadata: AttentionMetadata, + actual_encoder_seq_lens: List[int]): + + cross_attention_states_flat = torch.zeros( + sum(actual_encoder_seq_lens), + cross_attention_states.shape[-1], + device=cross_attention_states.device, + dtype=cross_attention_states.dtype) + start_pos = 0 + for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens, + cross_attention_states): + end_pos = start_pos + seq_len + cross_attention_states_flat[ + start_pos:end_pos] = vision_token_in_batch[:seq_len] + start_pos = end_pos + cross_attention_states = cross_attention_states_flat + return cross_attention_states + + def get_cross_attention_states( + self, + image_inputs: MllamaImagePixelInputs, + attn_metadata: AttentionMetadata, + actual_encoder_seq_lens: List[int], + ) -> Tuple[torch.Tensor]: + # NOTE: llama's reference implementation runs vision model on CPU + pixel_values = image_inputs['data'] + aspect_ratio_ids = image_inputs['aspect_ratio_ids'] + aspect_ratio_mask = image_inputs['aspect_ratio_mask'] + cross_attention_states = self.vision_model(pixel_values, + aspect_ratio_ids, + aspect_ratio_mask) + cross_attention_states, _ = self.multi_modal_projector( + cross_attention_states) + + bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + + cross_attention_states = self.flat_encoder_result( + cross_attention_states, attn_metadata, actual_encoder_seq_lens) + + return cross_attention_states + + def get_cross_attention_mask( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + num_tiles: List[List[int]], + num_tokens_per_tile: int, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + token_ids = input_ids.tolist() + start = 0 + batch_token_ids = [] + for seq_len in attn_metadata.seq_lens: + batch_token_ids.append(token_ids[start:start + seq_len]) + start += seq_len + sparse_mask = [ + get_cross_attention_token_mask(t, self.image_token_id) + for t in batch_token_ids + ] + + # Skip generating cross-attention mask if all samples + # are text-only or have only 1 leading image. 
+ if skip_attention_mask(sparse_mask): + return None, None + + dense_mask, tile_range_for_decode = \ + convert_sparse_cross_attention_mask_to_dense( + sparse_mask, num_tiles, attn_metadata.seq_lens) + cross_attention_mask = \ + convert_dense_cross_attention_mask_to_tensor( + dense_mask, num_tokens_per_tile, input_ids.device, dtype) + kv_range_for_decode = [[ + t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile + ] for t in tile_range_for_decode] + + return cross_attention_mask, kv_range_for_decode + + def get_full_text_row_masked_out_mask( + self, + attn_metadata: AttentionMetadata, + device: torch.device, + ) -> torch.Tensor: + full_text_row_masked_out_mask = torch.ones( + (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) + start_pos = 0 + for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens, + attn_metadata.encoder_seq_lens): + if encoder_seq_len == 0: + full_text_row_masked_out_mask[start_pos:start_pos + + seq_len] = False + start_pos += seq_len + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + device) + return full_text_row_masked_out_mask + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + **kwargs: object, + ) -> Union[Tuple, CausalLMOutputWithPast]: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: + raise ValueError("Chunk prefill not supported") + image_inputs = self._parse_and_validate_image_input(**kwargs) + cross_attention_states = None + cross_attention_mask = None + kv_range_for_decode = None + + # For 1) text-only prefill and decode, 2) image-present decode. + if image_inputs is None: + full_text_row_masked_out_mask = ( + attn_metadata.encoder_seq_lens_tensor + != 0).reshape(-1, 1).to(input_ids.device) + skip_cross_attention = attn_metadata.max_encoder_seq_len == 0 + + # For image-present prefill. + else: + skip_cross_attention = False + + # Get the actual number of encoder tokens for each sample. + # Because attn_metadata.encoder_seq_lens only counts the last + # group of images for each sample, which is used to cheat the + # block manager to allocate blocks for those images only. + # See MllamaMultiModalProcessor for more details. 
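+            # Descriptive note (added): each sample's true encoder length is its
+            # total tile count multiplied by the number of vision tokens per tile.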
+ num_tiles_tensor = kwargs.pop("num_tiles") + num_tiles = [t.tolist() for t in num_tiles_tensor] + num_tokens_per_tile = calc_token_per_chunk(self.image_size) + actual_encoder_seq_lens = [ + sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles + ] + for actual_len, last_group_len in zip( + actual_encoder_seq_lens, attn_metadata.encoder_seq_lens): + assert actual_len >= last_group_len + + cross_attention_states = self.get_cross_attention_states( + image_inputs, attn_metadata, actual_encoder_seq_lens) + + full_text_row_masked_out_mask = \ + self.get_full_text_row_masked_out_mask( + attn_metadata, input_ids.device) + + cross_attention_mask, kv_range_for_decode = \ + self.get_cross_attention_mask( + input_ids, attn_metadata, num_tiles, + num_tokens_per_tile, cross_attention_states.dtype) + + outputs = self.language_model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + kv_range_for_decode=kv_range_for_decode, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + skip_cross_attention=skip_cross_attention, + ) + + return outputs + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params: Set[str] = set() + for name, loaded_weight in weights: + if 'patch_embedding.weight' in name: + name = name.replace('patch_embedding.weight', + 'patch_embedding._linear.weight') + loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + updated_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + orig_name = name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + logger.debug("Missing name %s, orig name %s", name, + orig_name) + continue + + param = params_dict.pop(name) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + updated_params.add(name) + return updated_params + + +def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: + for mask in sparse_mask: + # Skip text-only samples. + if len(mask) == 0: + continue + # If the sample contains more than 1 images, + # we can't skip mask. + if len(mask) != 1: + return False + # If the sample contains only 1 image, + # but the image is not the leading one, + # we can't skip mask. 
+ if mask[0][0] != 0 or mask[0][1] != -1: + return False + return True + + +def convert_sparse_cross_attention_mask_to_dense( + sparse_mask: List[List[List[int]]], + num_tiles: List[List[int]], + lengths: List[int], +) -> Tuple[np.ndarray, List[Tuple[int, int]]]: + total_length = sum(lengths) + total_tiles = sum([sum(tiles) for tiles in num_tiles]) + dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64) + # A list of ranges, range[i] = [start, end] means that the i-th image will + # use tiles[start, end] for cross-attention decoding. + tile_range_for_decode = [] + + seq_start = 0 + tile_start = 0 + + # sparse_mask has an [] entry for each sequence that does not have images, + # but num_tiles does not have these entries... + num_tiles_idx = 0 + for masks, length in zip(sparse_mask, lengths): + if len(masks) == 0: + # Text only + continue + + tiles = num_tiles[num_tiles_idx] + num_tiles_idx += 1 + ts, td = -1, 0 + for mask, tile in zip(masks, tiles): + if len(mask) != 2: + continue + start, end = mask + end = min(end, length) + if end == -1: + end = length + if end == length: + if ts == -1: + ts = tile_start + td += tile + dense_mask[seq_start + start:seq_start + end, + tile_start:tile_start + tile] = 1 + tile_start += tile + assert ts != -1 + assert td != 0 + tile_range_for_decode.append((ts, ts + td)) + seq_start += length + assert num_tiles_idx == len(num_tiles) + + return dense_mask, tile_range_for_decode + + +def convert_dense_cross_attention_mask_to_tensor( + cross_attention_token_mask: np.ndarray, + num_tokens_per_tile: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device) + mask = mask.repeat_interleave(num_tokens_per_tile, dim=1) + + mask = 1.0 - mask + mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min) + + ninf = torch.finfo(dtype).min + full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None]) + mask *= full_text_mask + # (num_prompt_tokens, num_encoder_tokens) + return mask diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py new file mode 100644 index 0000000..012178c --- /dev/null +++ b/vllm/model_executor/models/mllama4.py @@ -0,0 +1,886 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
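+# Descriptive note (added): this module implements the Llama 4 multimodal model
+# for vLLM: the vision encoder, pixel-shuffle adapter, multimodal projector, and
+# the multimodal processor glue around the Llama4ForCausalLM language model.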
+import math +from collections.abc import Iterable, Mapping +from itertools import tee +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import BatchFeature, Llama4Config, Llama4VisionConfig +from transformers.image_utils import SizeDict +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.llama4 import Llama4Processor +from transformers.models.llama4.image_processing_llama4_fast import ( + find_supported_resolutions, get_best_fit) + +from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import InputProcessingContext +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features + +logger = init_logger(__name__) + + +class Llama4ImagePatchInputs(TypedDict): + type: Literal["pixel_values"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_chunks, num_channels, image size, image size)` + """ + patches_per_image: torch.Tensor + """ + The number of total patches for each image in the batch. + + This is used to split the embeddings which has the first two dimensions + flattened just like `flat_data`. + """ + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + """ + aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] + """ + A list of aspect ratios corresponding to the number of tiles + in each dimension that each image in the batch corresponds to. 
+ + Shape: + `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)` + """ + + +class Llama4VisionMLP(nn.Module): + + def __init__(self, + input_size: int, + intermediate_size: int, + output_size: int, + bias: bool, + output_activation: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.fc1 = ColumnParallelLinear( + input_size=input_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + input_size=intermediate_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.activation_fn = nn.GELU() + self.output_activation = output_activation + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + if self.output_activation: + return self.activation_fn(hidden_states) + return hidden_states + + +class Llama4MultiModalProjector(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear_1 = ColumnParallelLinear( + input_size=config.vision_config.vision_output_dim, + output_size=config.text_config.hidden_size, + bias=False, + quant_config=quant_config, + gather_output=True, + prefix=f"{prefix}.linear_1", + ) + + def forward(self, image_features): + hidden_states, _ = self.linear_1(image_features) + return hidden_states + + +def pixel_shuffle(input_tensor, shuffle_ratio): + # input_tensor: [batch_size, num_patches, channels] + batch_size, num_patches, channels = input_tensor.shape + patch_size = int(math.sqrt(num_patches)) + + input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) + batch_size, height, width, channels = input_tensor.size() + + reshaped_tensor = input_tensor.view(batch_size, height, + int(width * shuffle_ratio), + int(channels / shuffle_ratio)) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + reshaped_tensor = reshaped_tensor.view(batch_size, + int(height * shuffle_ratio), + int(width * shuffle_ratio), + int(channels / (shuffle_ratio**2))) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + output_tensor = reshaped_tensor.view(batch_size, -1, + reshaped_tensor.shape[-1]) + return output_tensor + + +class Llama4VisionPixelShuffleMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.pixel_shuffle_ratio = config.pixel_shuffle_ratio + self.inner_dim = int(config.projector_input_dim // + (self.pixel_shuffle_ratio**2)) + self.output_dim = config.projector_output_dim + self.mlp = Llama4VisionMLP( + input_size=config.intermediate_size, + intermediate_size=config.projector_input_dim, + output_size=config.projector_output_dim, + bias=config.multi_modal_projector_bias, + output_activation=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: + encoded_patches = pixel_shuffle(encoded_patches, + self.pixel_shuffle_ratio) + return self.mlp(encoded_patches) + + +class Llama4VisionAttention(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ): + super().__init__() + self.config = config + self.tp_size = 
get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // self.num_heads + assert self.num_heads % self.tp_size == 0 + self.num_local_heads = self.num_heads // self.tp_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + self.attention_dropout = config.attention_dropout + self.scaling = self.head_dim**-0.5 + + self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim, + self.scaling) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=True, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=config.hidden_size // config.num_attention_heads // 2, + # number of image patches + max_position=(config.image_size // config.patch_size)**2, + base=config.rope_theta, + rope_scaling={"rope_type": "mllama4"}, + is_neox_style=False, + dtype=torch.complex64, # important + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_shape = hidden_states.shape[:-1] + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim) + q, k = self.rotary_emb(q, k) + + q = q.view(q.shape[0], q.shape[1], -1) + k = k.view(k.shape[0], k.shape[1], -1) + + attn_output = self.attn(q, k, v) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output, _ = self.o_proj(attn_output) + + return attn_output + + +class Llama4VisionEncoderLayer(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.intermediate_size = config.intermediate_size + + self.self_attn = Llama4VisionAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = Llama4VisionMLP(input_size=config.hidden_size, + intermediate_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + output_activation=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = nn.LayerNorm(config.hidden_size) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_state: torch.Tensor, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state) + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state = residual + hidden_state + + outputs = (hidden_state, ) + return outputs + + +class Llama4VisionEncoder(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Llama4VisionEncoderLayer( + config, + quant_config=quant_config, + 
prefix=f"{prefix}.layers.{layer_idx}", + ) for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if you + want more control over how to convert `input_ids` indices into + associated vectors than the model's internal embedding + lookup matrix. + """ + + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs[0] + + return BaseModelOutput(last_hidden_state=hidden_states, ) + + +class Llama4UnfoldConvolution(nn.Module): + + def __init__(self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + kernel_size = config.patch_size + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self.unfold = torch.nn.Unfold(kernel_size=kernel_size, + stride=config.patch_size) + self.linear = ColumnParallelLinear(config.num_channels * + kernel_size[0] * kernel_size[1], + config.hidden_size, + bias=False, + quant_config=quant_config, + gather_output=True, + prefix=f"{prefix}.linear") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.unfold(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states, _ = self.linear(hidden_states) + return hidden_states + + +class Llama4VisionModel(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = Llama4UnfoldConvolution( + config, + quant_config=quant_config, + prefix=f"{prefix}.patch_embedding") + + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.positional_embedding_vlm = nn.Parameter( + self.scale * torch.randn(self.num_patches, self.hidden_size)) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5) + self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5) + + # encoders + self.model = Llama4VisionEncoder(config, + quant_config=quant_config, + prefix=f"{prefix}.model") + self.vision_adapter = Llama4VisionPixelShuffleMLP( + config, quant_config, prefix=f"{prefix}.vision_adapter") + + def forward( + self, + images_flattened: torch.Tensor, + ) -> BaseModelOutput: + # Patch embedding + hidden_state = self.patch_embedding(images_flattened) + num_tiles, num_patches, hidden_dim = hidden_state.shape + + # Add cls token + class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, + hidden_state.shape[-1]) + hidden_state = torch.cat([hidden_state, class_embedding], dim=1) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + num_tiles, + 1, + num_patches, + hidden_dim, + ) + positional_embedding = self.positional_embedding_vlm.to( + dtype=hidden_state.dtype, device=hidden_state.device) + hidden_state = hidden_state + positional_embedding + hidden_state = self.layernorm_pre(hidden_state) + hidden_state = 
hidden_state.view(num_tiles, -1, hidden_dim) + + # Apply encoder + output = self.model(hidden_state) + hidden_state = output.last_hidden_state + hidden_state = self.layernorm_post(hidden_state) + + # Remove CLS token output + hidden_state = hidden_state[:, :-1, :] + + # now, we use Llama4VisionPixelShuffle + mlp to project embeddings + hidden_state = self.vision_adapter(hidden_state) + + return BaseModelOutput( + last_hidden_state=hidden_state, + attentions=None, + ) + + +class Mllama4ProcessingInfo(BaseProcessingInfo): + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__(ctx) + + def get_hf_config(self) -> Llama4Config: + return self.ctx.get_hf_config(Llama4Config) + + def get_hf_processor(self, **kwargs: object) -> Llama4Processor: + return self.ctx.get_hf_processor(Llama4Processor, + use_fast=True, + **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 10} + + @staticmethod + def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int: + image_size = vision_config.image_size + patch_size = vision_config.patch_size + + assert ( + image_size % + patch_size == 0), f"chunk size {image_size} should be multiple of " + f"patch_size {patch_size}" + + ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2))) + return (image_size // patch_size)**2 // ds_ratio + + def get_max_num_tiles(self) -> int: + image_processor = self.get_hf_processor().image_processor + return image_processor.max_patches + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + vision_config = self.get_hf_config().vision_config + # image_start + local tiles * (patches + 1 x separator) + + # 1 global tile * (image x 1 + patches) + image_end + token_per_chunk = self.get_patch_per_chunk(vision_config) + 1 + mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2 + return {"image": mm_max_tokens} + + def get_image_size_with_most_features(self) -> ImageSize: + vision_config = self.get_hf_config().vision_config + image_size = vision_config.image_size + # Result in the max possible feature size (h:w = 16:1) + return ImageSize(height=self.get_max_num_tiles() * image_size, + width=image_size) + + +class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] + ): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + + if mm_data is None: + return tokenizer(prompt, add_special_tokens=False) # exclude bos + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + processor = self.info.get_hf_processor(**mm_kwargs) + image_processor = processor.image_processor + vision_config = self.info.get_hf_config().vision_config + + if processed_outputs.get("pixel_values") is not None: + assert "images" in mm_data, \ + "images expected to be in mm_data when pixel_values is present" + + images = mm_data["images"] + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems)) + + tile_size = vision_config.image_size + possible_resolutions = find_supported_resolutions( + max_num_chunks=self.info.get_max_num_tiles(), + patch_size=SizeDict(height=tile_size, width=tile_size), + ) + best_fit_sizes = [ + get_best_fit( + (image.size[1], image.size[0]), + torch.tensor(possible_resolutions), + 
resize_to_max_canvas=image_processor.resize_to_max_canvas) + for image in parsed_images + ] + # TODO tile height/width do not necessarily need to match + aspect_ratios = [(image_size[0] // tile_size, + image_size[1] // tile_size) + for image_size in best_fit_sizes] + patches_per_image = [ + 1 if r_h * r_w == 1 else 1 + r_h * r_w + for (r_h, r_w) in aspect_ratios + ] + + # embed_is_patch should have one feature per image-related token: + # <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|> + # -> False + # <|patch|> -> True + # embed_is_patch has no entries corresponding to non-image-related + # tokens. + patch_id = tokenizer.get_vocab()[processor.img_patch_token] + num_patches_per_chunk = self.info.get_patch_per_chunk( + vision_config) + expanded_image_tokens_list = [ + processor._prompt_split_image(aspect_ratio, + num_patches_per_chunk) + for aspect_ratio in aspect_ratios + ] + expanded_image_token_ids = [ + tokenizer.encode(image_tokens, add_special_tokens=False) + for image_tokens in expanded_image_tokens_list + ] + embed_is_patch = [ + torch.tensor(tokens) == patch_id + for tokens in expanded_image_token_ids + ] + + processed_outputs["aspect_ratios"] = aspect_ratios + processed_outputs["patches_per_image"] = torch.tensor( + patches_per_image) + processed_outputs["embed_is_patch"] = embed_is_patch + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0)) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", patches_per_image), + patches_per_image=MultiModalFieldConfig.batched("image"), + aspect_ratios=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> List[PromptUpdate]: + assert ( + mm_items.get_count("image", strict=False) == 0 + or "aspect_ratios" in out_mm_kwargs + ), "Transformers expect to include aspect_ratios in out_mm_kwargs" + + config = self.info.get_hf_config() + vision_config = config.vision_config + + num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.image_token + + def get_replacement(item_idx: int): + aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] + return hf_processor._prompt_split_image( + aspect_ratio=aspect_ratio, + num_patches_per_chunk=num_patches_per_chunk) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement, + ) + ] + + +class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + (target_width, + target_height) = self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + image_token = self.info.get_hf_processor().fake_image_token + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Mllama4MultiModalProcessor, + info=Mllama4ProcessingInfo, + 
dummy_inputs=Mllama4DummyInputsBuilder, +) +class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.vision_model = Llama4VisionModel(config.vision_config, + None, + prefix=maybe_prefix( + prefix, "vision_model")) + self.multi_modal_projector = Llama4MultiModalProjector( + self.config, + None, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + architectures=["Llama4ForCausalLM"], + prefix=maybe_prefix(prefix, "language_model")) + + self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: + # num_images, 1, num_chunks, channel, image_size, image_size + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + # num_images x num_chunks, channel, image_size, image_size + # TODO: confirm handling for variable lengths + flat_pixel_values = flatten_bn(pixel_values, concat=True) + patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) + + embed_is_patch = kwargs.pop("embed_is_patch", None) + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + aspect_ratios = kwargs.pop("aspect_ratios", None) + if not isinstance(aspect_ratios, (torch.Tensor, list)): + raise ValueError("Incorrect type of aspect_ratios. 
" + f"Got type: {type(aspect_ratios)}") + + return Llama4ImagePatchInputs( + type="pixel_values", + flat_data=flat_pixel_values, + patches_per_image=patches_per_image, + embed_is_patch=embed_is_patch, + aspect_ratios=aspect_ratios, + ) + + def _process_image_input( + self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: + flat_data = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"].tolist() + vision_embeddings_flat = self.vision_model(flat_data).last_hidden_state + return vision_embeddings_flat.split(patches_per_image, dim=0) + + def get_multimodal_embeddings(self, + **kwargs) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + # num_images x [num_chunks, num_patches, hidden_dim] + image_features = self._process_image_input(image_input) + # num_images x [num_chunks x num_patches, hidden_dim] + image_features_flat = [img.flatten(0, 1) for img in image_features] + # num_images x [1, input_len] -> num_images x [input_len] + embed_is_patch_flat = [ + is_patch.flatten(0, 1) + for is_patch in image_input["embed_is_patch"] + ] + + return scatter_patch_features( + image_features_flat, + embed_is_patch_flat, + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + multimodal_embeddings = torch.cat(multimodal_embeddings) + mm_embeddings = self.multi_modal_projector(multimodal_embeddings) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, select_patch_features(mm_embeddings), + self.config.image_token_index) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ if "pixel_values" in kwargs: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + return self.language_model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def separate_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + prefix: str, + ) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[ + str, torch.Tensor]]]: + weights1, weights2 = tee(weights, 2) + + def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]: + for name, data in weights1: + if name.startswith(prefix): + yield (name, data) + + def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]: + for name, data in weights2: + if not name.startswith(prefix): + yield (name, data) + + return get_prefix_weights(), get_other_weights() + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + updated_params: Set[str] = set() + + # language_model is an Llama4ForCausalLM instance. We load it's + # using llama4's load_weights routine. + language_model_prefix = "language_model.model." 
+ language_model_weights, other_weights = self.separate_weights( + weights, prefix=language_model_prefix) + loader = AutoWeightsLoader(self) + loaded_language_model_params = loader.load_weights( + language_model_weights) + assert loaded_language_model_params is not None + updated_params.update(loaded_language_model_params) + + for name, loaded_weight in other_weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + weight_loader(param, loaded_weight) + updated_params.add(name) + return updated_params diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py new file mode 100644 index 0000000..a9dcae4 --- /dev/null +++ b/vllm/model_executor/models/mlp_speculator.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Iterable, List, Set, Tuple + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP + +SQRT2 = 2**0.5 + + +class MLPSpeculatorLayerNorm(nn.Module): + """ + A L2 normalization implementation + ... + Args + ---- + normalized_shape : int + Dimensionality of input data (size of final tensor axis) + eps : float + Safety term to prevent division by zero. Make sure the chosen value + fits in the range of your encoding scheme + (i.e. fp16 requires eps >= 6e-8). + elementwise_scale_and_shift : bool + Include a learned scaling and shift term after normalization. 
+ """ + + def __init__( + self, + normalized_shape, + eps=1e-06, + elementwise_scale_and_shift=True, + ): + super().__init__() + self.elementwise_scale_and_shift = elementwise_scale_and_shift + if self.elementwise_scale_and_shift: + self.weight = nn.Parameter(torch.empty(normalized_shape)) + self.bias = nn.Parameter(torch.empty(normalized_shape)) + self.eps = eps + + def forward(self, x): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + x = xf.type_as(x) + if self.elementwise_scale_and_shift: + x = self.weight * x + x = x + self.bias + return x + +class MLPSpeculator(nn.Module, SupportsPP): + """ + An implementation of the speculative models introduced in + "Accelerating Production LLMs with Combined Token/Embedding + Speculators" + https://arxiv.org/pdf/2404.19124 + + Trained speculators of this type are available on HF hub at: + https://huggingface.co/ibm-ai-platform and https://huggingface.co/ibm-granite + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + self.n_predict = config.n_predict + self.vocab_size = config.vocab_size + self.emb_dim = config.emb_dim + self.inner_dim = config.inner_dim if config.inner_dim != 0 \ + else config.emb_dim + + self.max_speculative_tokens = config.num_lookahead_tokens + + self.tie_weights = config.tie_weights + self.scale_input = config.scale_input + + if self.tie_weights: + assert ( + self.n_predict > 1 + ), "You cannot tie weights between stages when only 1 exists" + embedding = VocabParallelEmbedding( + config.vocab_size, + self.inner_dim, + org_num_embeddings=config.vocab_size) + self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens) + + # the initial projection from the base model may + # have a different size, so that stays separate. 
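+            # Descriptive note (added): the first projection maps emb_dim ->
+            # inner_dim; the remaining speculative heads share a single
+            # inner_dim -> inner_dim projection when weights are tied.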
+ proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False) + proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False) + self.proj = nn.ModuleList([proj_first] + [proj_tied] * + (self.max_speculative_tokens - 1)) + + head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) + self.head = nn.ModuleList([head] * self.max_speculative_tokens) + + ln = MLPSpeculatorLayerNorm(self.inner_dim, + elementwise_scale_and_shift=True) + self.ln = nn.ModuleList([ln] * self.max_speculative_tokens) + + else: + self.emb = nn.ModuleList([ + VocabParallelEmbedding(config.vocab_size, + self.inner_dim, + org_num_embeddings=config.vocab_size) + for _ in range(self.max_speculative_tokens) + ]) + + self.proj = nn.ModuleList([ + nn.Linear((self.emb_dim if i == 0 else self.inner_dim), + self.inner_dim, + bias=False) + for i in range(self.max_speculative_tokens) + ]) + + self.head = nn.ModuleList([ + ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) + for _ in range(self.max_speculative_tokens) + ]) + self.ln = nn.ModuleList([ + MLPSpeculatorLayerNorm(self.inner_dim, + elementwise_scale_and_shift=True) + for _ in range(self.max_speculative_tokens) + ]) + if self.scale_input: + self.ln0 = MLPSpeculatorLayerNorm( + self.emb_dim, elementwise_scale_and_shift=False) + + self.state_weight = 0.5**(0.5 / config.n_predict) + self.emb_weight = math.sqrt( + (1 - self.state_weight**2) * (self.inner_dim / 2)) + self.activation = nn.GELU() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + config.vocab_size, 1.0) + self.sampler = get_sampler() + + def generate_proposals( + self, + input_ids: torch.Tensor, + previous_hidden_states: torch.Tensor, + num_predict_tokens: int, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + if num_predict_tokens > self.max_speculative_tokens: + raise ValueError(f"Max speculative tokens for model is " + f"{self.max_speculative_tokens}, but " + f"{num_predict_tokens} were requested") + + # b x 1 x d + previous_hidden_states = previous_hidden_states.unsqueeze(1) + + if self.scale_input: + previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 + + # b x 1 + last_tokens = input_ids.unsqueeze(1) + + next_tokens = [] + + for head_index in range(num_predict_tokens): + + # Project and predict + z = self.emb[head_index](last_tokens) # b k d + states = self.proj[head_index](previous_hidden_states) + + # Weighted add of state_weight*state and emb_weight*z + # Let subsequent LN take care of denominator + # state_weight is close to 1, so shouldn't be any precision issues + states.add_(z, alpha=self.emb_weight / self.state_weight) + + states = self.activation(self.ln[head_index](states)) # b k d + previous_hidden_states = states + # TODO: not yet supporting top_k_tokens_per_head + states = states.flatten(0, 1) + + logits = self.logits_processor(self.head[head_index], states, + sampling_metadata) + + output = self.sampler(logits, sampling_metadata) + last_tokens = output.sampled_token_ids + next_tokens.append(output) + + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + name = name.replace("speculator.", "") + param = params_dict.get(name) + if param is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git 
a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py new file mode 100644 index 0000000..23814e6 --- /dev/null +++ b/vllm/model_executor/models/module_mapping.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py + +from dataclasses import dataclass, field +from typing import List, Union + + +@dataclass +class ModelKeys: + model_type: str = None + + module_list: str = None + + embedding: str = None + + mlp: str = None + + down_proj: str = None + + attention: str = None + + o_proj: str = None + + q_proj: str = None + + k_proj: str = None + + v_proj: str = None + + qkv_proj: str = None + + qk_proj: str = None + + qa_proj: str = None + + qb_proj: str = None + + kva_proj: str = None + + kvb_proj: str = None + + output: str = None + + +@dataclass +class MultiModelKeys(ModelKeys): + language_model: List[str] = field(default_factory=list) + connector: List[str] = field(default_factory=list) + # vision tower and audio tower + tower_model: List[str] = field(default_factory=list) + generator: List[str] = field(default_factory=list) + + @staticmethod + def from_string_field(language_model: Union[str, List[str]] = None, + connector: Union[str, List[str]] = None, + tower_model: Union[str, List[str]] = None, + generator: Union[str, List[str]] = None, + **kwargs) -> 'MultiModelKeys': + + def to_list(value): + if value is None: + return [] + return [value] if isinstance(value, str) else list(value) + + return MultiModelKeys(language_model=to_list(language_model), + connector=to_list(connector), + tower_model=to_list(tower_model), + generator=to_list(generator), + **kwargs) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py new file mode 100644 index 0000000..b2f7951 --- /dev/null +++ b/vllm/model_executor/models/molmo.py @@ -0,0 +1,1639 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from functools import cached_property, partial +from typing import List, Optional, Set, Tuple, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, + TensorType) +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput + +from vllm.attention import Attention +from vllm.attention.layer import MultiHeadAttention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, + SiluAndMul) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from 
vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+                                   MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                         BaseProcessingInfo, PromptIndexTargets,
+                                         PromptInsertion, PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsPP, SupportsQuant)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
+
+# TODO: hard-coded for now. Consider making it configurable.
+VIT_LAYERS = [-2, -9]
+NUM_PREFIX_TOKENS = 1
+ADDITIONAL_VOCAB_SIZE = 128
+IMAGE_PATCH_TOKEN = "<im_patch>"
+IM_COL_TOKEN = "<im_col>"
+IM_START_TOKEN = "<im_start>"
+IM_END_TOKEN = "<im_end>"
+POOLING_SIZE = 2
+
+
+class MolmoImageInputs(TypedDict):
+    images: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
+
+    image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
+    """Shape: `(batch_size * num_images, num_crops, num_patch)`"""
+
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image features correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_crops, num_patch)`
+    """
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
+    num_crops: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+
+@dataclass
+class VisionBackboneConfig:
+    image_default_input_size: Tuple[int, int] = (336, 336)
+    image_patch_size: int = 14
+    image_pos_patch_size: int = 14
+    image_emb_dim: int = 1024
+    image_num_heads: int = 16
+    image_num_key_value_heads: int = 16
+    image_num_layers: int = 23
+    image_mlp_dim: int = 4096
+    image_mlp_activations: str = "quick_gelu"
+    image_num_pos: int = 577
+    image_norm_eps: float = 1e-5
+
+    def __post_init__(self):
+        self.image_default_input_size = tuple(
+            self.image_default_input_size)  # type: ignore[assignment]
+
+    @property
+    def image_num_patch(self):
+        h, w = self.image_default_input_size
+        return h // self.image_patch_size, w // self.image_patch_size
+
+
+class ViTMLP(nn.Module):
+    """MLP used in Vision Transformer."""
+
+    def __init__(
+        self,
+        config: VisionBackboneConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.w1 = ColumnParallelLinear(
+            config.image_emb_dim,
+            config.image_mlp_dim,
+            bias=True,
+            quant_config=quant_config,
+        )
+        # Activation function.
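+        # QuickGELU (x * sigmoid(1.702 * x)) is the activation used by
+        # CLIP-style ViTs, which Molmo's vision backbone follows.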
+ assert config.image_mlp_activations == "quick_gelu" + self.act = QuickGELU() + self.w2 = RowParallelLinear( + config.image_mlp_dim, + config.image_emb_dim, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.w1(x) + x = self.act(x) + x, _ = self.w2(x) + return x + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head attention used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + use_bias: bool = True, + nlayers: int = 1, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.hidden_size = config.image_emb_dim + self.total_num_heads = config.image_num_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + + self.total_num_kv_heads = config.image_num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.wq = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wk = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wv = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=use_bias, + quant_config=quant_config, + ) + + self.scale = self.head_dim**-0.5 + self.attn = MultiHeadAttention(self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads) + + def forward(self, + inputs_q: torch.Tensor, + inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: + + if inputs_kv is not None: + inputs_k = inputs_kv + inputs_v = inputs_kv + else: + inputs_k = inputs_q + inputs_v = inputs_q + + xq, _ = self.wq(inputs_q) + xk, _ = self.wk(inputs_k) + xv, _ = self.wv(inputs_v) + + output = self.attn(xq, xk, xv) + output, _ = self.wo(output) + + return output + + +class ResidualAttentionBlock(nn.Module): + """Residual attention block used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.attention = MultiHeadDotProductAttention( + config, quant_config=quant_config) + self.feed_forward = ViTMLP(config, quant_config) + self.attention_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + self.ffn_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attention(self.attention_norm(x)) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class BlockCollection(nn.Module): + """Collection of residual attention blocks used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock(config, quant_config) + for _ in range(config.image_num_layers) + ]) + + def forward(self, x: 
torch.Tensor) -> List[torch.Tensor]: + hidden_states = [] + for r in self.resblocks: + x = r(x) + hidden_states.append(x) + return hidden_states + + +def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor: + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class VisionTransformer(nn.Module): + """Vision Transformer used in Vision Backbone.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + scale = config.image_emb_dim**-0.5 + self.patch_num = config.image_num_patch + self.class_embedding = nn.Parameter( + torch.randn(config.image_emb_dim) * scale) + self.num_prefix_tokens: int = NUM_PREFIX_TOKENS + self.positional_embedding = nn.Parameter( + torch.randn(config.image_num_pos, config.image_emb_dim) * scale) + image_patch_size = config.image_patch_size + self.patch_embedding = nn.Linear( + image_patch_size * image_patch_size * 3, + config.image_emb_dim, + bias=False, + ) + self.pre_ln = nn.LayerNorm(config.image_emb_dim, + eps=config.image_norm_eps) + self.transformer = BlockCollection(config, quant_config) + + def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: + cls_emb = self.positional_embedding[0:1] + pos_emb = self.positional_embedding[1:] + + pos_emb = pos_emb.reshape( + (int(math.sqrt(pos_emb.shape[0])), + int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) + + (patch_num_0, patch_num_1) = patch_num + + if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: + # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) + pos_emb = F.interpolate( + pos_emb, + size=(patch_num_0, patch_num_1), + mode="bicubic", + align_corners=False, + antialias=True, + ) + pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) + + pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) + x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], + dim=1).to(x.dtype) + return x + + def forward(self, + x: torch.Tensor, + patch_num: Optional[int] = None) -> List[torch.Tensor]: + """ + : param x: (batch_size, num_patch, n_pixels) + """ + if patch_num is None: + patch_num = self.patch_num + B, N, D = x.shape + + x = self.patch_embedding(x) + + # class embeddings and positional embeddings + x = torch.cat( + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], + dim=1) + x = self.add_pos_emb(x, patch_num) + + x = self.pre_ln(x) + + hidden_states = self.transformer(x) + return hidden_states + + +class MolmoAttention(nn.Module): + """Molmo's LLM attention.""" + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = config.num_key_value_heads \ + or self.total_num_heads + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + 
self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + + self.tp_rank: Optional[int] = None + self.k_norm: Optional[nn.Module] = None + self.q_norm: Optional[nn.Module] = None + if config.attention_layer_norm: + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, + eps=config.layer_norm_eps) + self.q_norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.q_norm is not None and self.k_norm is not None: + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class LanguageModelMLP(nn.Module): + """Molmo's LLM mlp.""" + + def __init__(self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + + self.gate_up_proj = MergedColumnParallelLinear( + input_dim or self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + # Activation function. + self.act_fn = MulAndSilu() + # Feed-forward output projection. 
+ self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class ImageProjectorMLP(nn.Module): + """Molmo's image_projector mlp.""" + + def __init__( + self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + + self.merged_linear = MergedColumnParallelLinear( + input_dim or self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.merged_linear(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MolmoDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Attention block. + self.self_attn = MolmoAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") + + # MLP block. + self.mlp = LanguageModelMLP(config, quant_config=quant_config) + + # LayerNorm + assert config.layer_norm_type == "rms" + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = None + return hidden_states, residual + + +class MolmoVisionBackbone(nn.Module, SupportsQuant): + packed_modules_mapping = {"merged_linear": ["gate_proj", "up_proj"]} + + def __init__( + self, + config: PretrainedConfig, + vision_config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + 
super().__init__() + self.vit_layers = VIT_LAYERS + self.image_num_patch = vision_config.image_num_patch + self.llm_patches_per_crop = ( + (self.image_num_patch[0] + 1) // POOLING_SIZE, + (self.image_num_patch[1] + 1) // POOLING_SIZE, + ) + self.image_vit = VisionTransformer(vision_config, + quant_config=quant_config) + self.num_prefix_tokens = self.image_vit.num_prefix_tokens + assert self.num_prefix_tokens in { + 0, 1 + }, "Only 0 or 1 prefix tokens are supported" + self.image_pooling_2d = MultiHeadDotProductAttention( + vision_config, + nlayers=len(self.vit_layers), + quant_config=quant_config) + self.image_projector = ImageProjectorMLP( + config, + input_dim=vision_config.image_emb_dim, + quant_config=quant_config, + ) + + image_dim = vision_config.image_emb_dim * len(self.vit_layers) + self.pad_embed = nn.Parameter(torch.zeros((2, image_dim))) + + @property + def dtype(self) -> torch.dtype: + return self.image_vit.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.image_vit.patch_embedding.weight.device + + def encode_image(self, images: torch.Tensor) -> torch.Tensor: + """ + : param images: (batch_size, num_crops, num_patch, n_pixels) + """ + B, T, N, D = images.shape + + mask = ~torch.all( + images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) + + images = images.view(B * T, N, D) + image_features = self.image_vit(images) + + if self.vit_layers is not None: + features = [] + for layer in self.vit_layers: + features.append(image_features[layer]) + image_features = torch.cat(features, dim=-1) + else: + image_features = image_features[-1] + + if self.num_prefix_tokens > 0: + image_features = image_features[:, 1:] + + image_features = image_features * mask + image_features = image_features.view(B, T, N, -1) + + return image_features + + def forward( + self, + images: torch.Tensor, + image_masks: torch.Tensor, + ) -> torch.Tensor: + # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501 + batch_size, num_image = images.shape[:2] + images = images.to(device=self.device, dtype=self.dtype) + image_features = self.encode_image(images) + + og_dtype = image_features.dtype + assert image_masks is not None + pad_embed = self.pad_embed[:, None, None, None, :] + all_pad = image_masks == 0 + partial_pad = torch.logical_and( + image_masks < 1, + torch.logical_not(all_pad)).to(dtype=torch.float32) + all_pad = all_pad.to(dtype=torch.float32) + image_features = image_features + pad_embed[0] * torch.unsqueeze( + all_pad, -1) + image_features = image_features + pad_embed[1] * torch.unsqueeze( + partial_pad, -1) + + image_features = image_features.to(og_dtype) + + image_features = image_features.reshape( + (batch_size, num_image) + self.image_num_patch + (-1, ), ) + + if (missing_w := self.image_num_patch[0] % POOLING_SIZE): + # Padding for image pooling (see below) + image_features = F.pad( + image_features, + (0, 0, 0, missing_w, 0, missing_w, 0, 0, 0, 0), + ) + + # image pooling + image_features = rearrange( + image_features, + 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', + dh=POOLING_SIZE, + dw=POOLING_SIZE, + ) + + query = image_features.mean(-2, keepdim=True) + image_features = self.image_pooling_2d(query, image_features) + + h, w = self.llm_patches_per_crop + image_features = image_features.view(batch_size, num_image, h * w, -1) + + image_features = self.image_projector(image_features) + + # image_features: (batch_size, num_image, num_patch, d_model) + return image_features + + def load_weights(self, weights: 
Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("merged_linear", "gate_proj", 0), + ("merged_linear", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile +class MolmoModel(nn.Module, SupportsQuant): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embedding_size = config.embedding_size or config.vocab_size + self.embedding_size += ADDITIONAL_VOCAB_SIZE + self.embed_tokens = VocabParallelEmbedding( + self.embedding_size, + config.hidden_size, + quant_config=quant_config, + ) + + decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \ + else MolmoDecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + + assert config.layer_norm_type == "rms" + self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Apply blocks one-by-one. 
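+        # Only the decoder layers assigned to this pipeline-parallel rank are
+        # executed here; start_layer/end_layer bound the local slice.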
+ for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def _lowest_multiple(x: int, k: int) -> int: + return (x // k) * k + + +def get_num_patches( + num_tiles: int, + *, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int, +) -> int: + if num_tiles == 1: + return _lowest_multiple(crop_patches + pooling_size - 1, pooling_size) + + crop_window_patches = crop_patches - (left_margin + right_margin) + + left_num = _lowest_multiple( + crop_window_patches + left_margin + pooling_size - 1, + pooling_size, + ) + middle_num = _lowest_multiple( + crop_window_patches + pooling_size - 1, + pooling_size, + ) + right_num = _lowest_multiple( + crop_window_patches + right_margin + pooling_size - 1, + pooling_size, + ) + + return left_num + (num_tiles - 2) * middle_num + right_num + + +def get_patches_grid_size( + *, + tiling_h: int, + tiling_w: int, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int, +) -> tuple[int, int]: + nrows = get_num_patches( + tiling_h, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) + ncols = get_num_patches( + tiling_w, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) + + return nrows, ncols + + +def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]: + tilings = [(i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) if i * j <= max_num] + return sorted(tilings, key=lambda x: x[0] * x[1]) + + +def select_tiling( + *, + height: int, + width: int, + patch_size: int, + max_num_patches: int, +): + tilings = get_candidate_tilings(max_num_patches) + candidate_tilings = np.array(tilings, dtype=np.int32) + candidate_resolutions = candidate_tilings * patch_size + + original_size = np.array([height, width], dtype=np.float32) + required_scale_d = candidate_resolutions.astype(np.float32) / original_size + required_scale = required_scale_d.min(axis=-1, keepdims=True) + + if (required_scale < 1).all(): + ix = required_scale.argmax() + else: + ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin() + + return candidate_tilings[ix] + + +class MolmoProcessorWrapper: + """ + Wraps :class:`MolmoProcessor` so that it can be called directly. 
+ + The original definition can be found here: + https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py + """ + + def __init__(self, processor: ProcessorMixin): + super().__init__() + + self.processor = processor + + @cached_property + def vocab(self) -> dict[str, int]: + return self.processor.tokenizer.vocab # type: ignore + + @cached_property + def max_crops(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + max_crops = image_processor.max_crops + assert isinstance(max_crops, int) + + return max_crops + + @cached_property + def base_image_input_size(self) -> tuple[int, int]: + image_processor = self.processor.image_processor # type: ignore + + base_image_input_size = image_processor.base_image_input_size + if isinstance(base_image_input_size, int): + return base_image_input_size, base_image_input_size + + return tuple(base_image_input_size) + + @cached_property + def image_patch_size(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_patch_size = image_processor.image_patch_size + assert isinstance(image_patch_size, int) + + return image_patch_size + + @cached_property + def overlap_margins(self) -> tuple[int, int]: + image_processor = self.processor.image_processor # type: ignore + + left_margin, right_margin = image_processor.overlap_margins + assert isinstance(left_margin, int) + assert isinstance(right_margin, int) + + return left_margin, right_margin + + @cached_property + def image_token_length_w(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_token_length_w = image_processor.image_token_length_w + assert isinstance(image_token_length_w, int) + + return image_token_length_w + + @cached_property + def image_token_length_h(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_token_length_h = image_processor.image_token_length_h + assert isinstance(image_token_length_h, int) + + return image_token_length_h + + @property + def message_format(self) -> Optional[str]: + return "role" + + @property + def always_start_with_space(self) -> bool: + return True + + @cached_property + def image_patch_id(self) -> int: + return self.vocab[IMAGE_PATCH_TOKEN] + + @cached_property + def im_col_id(self) -> int: + return self.vocab[IM_COL_TOKEN] + + @cached_property + def im_start_id(self) -> int: + return self.vocab[IM_START_TOKEN] + + @cached_property + def im_end_id(self) -> int: + return self.vocab[IM_END_TOKEN] + + @property + def pooling_size(self) -> int: + return POOLING_SIZE + + def select_tiling( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + max_crops = self.max_crops + left_margin, right_margin = self.overlap_margins + base_image_input_size = self.base_image_input_size + base_image_input_d = self.image_patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = base_image_input_size[0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + tiling_h, tiling_w = select_tiling( + height=image_height - total_margin_pixels, + width=image_width - total_margin_pixels, + patch_size=crop_window_size, + max_num_patches=max_crops, + ) + + return tiling_w, tiling_h + + def get_patches_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + left_margin, right_margin = self.overlap_margins + base_image_input_size = 
self.base_image_input_size + base_image_input_d = self.image_patch_size + pooling_size = self.pooling_size + + crop_patches = base_image_input_size[0] // base_image_input_d + tiling_w, tiling_h = self.select_tiling( + image_height=image_height, + image_width=image_width, + ) + + nrows, ncols = get_patches_grid_size( + tiling_h=tiling_h, + tiling_w=tiling_w, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) + + return ncols, nrows + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + outputs = self.processor.process( # type: ignore + text, images, **kwargs) + + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + input_ids: torch.Tensor = outputs.pop("input_ids") + outputs["input_ids"] = input_ids.unsqueeze(0) + + image_input_idx = outputs.pop("image_input_idx", None) + if image_input_idx is not None: + feat_is_patch = image_input_idx >= 0 + + input_is_embed = torch.isin( + input_ids, + torch.tensor([ + self.image_patch_id, + self.im_col_id, + self.im_start_id, + self.im_end_id, + ]), + ) + embed_ids = input_ids[input_is_embed] + embed_is_patch = embed_ids == self.image_patch_id + assert embed_is_patch.sum() == feat_is_patch.sum() + + # image_tokens = extra_joint + joint + # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id` + embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0] + embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0] + assert len(embed_start) == len(embed_end) == len(images) + + embed_is_patch = [ + embed_is_patch[start:end + 1] + for start, end in zip(embed_start, embed_end) + ] + + tilings = [ + self.select_tiling( + image_width=image.size[0], + image_height=image.size[1], + ) for image in images + ] + # For each image: tiling_h * tiling_w + extra + num_crops = torch.tensor(tilings).prod(-1) + 1 + assert num_crops.sum() == len(feat_is_patch) + + outputs["feat_is_patch"] = feat_is_patch + outputs["embed_is_patch"] = embed_is_patch + outputs["num_crops"] = num_crops + outputs["img_patch_id"] = self.image_patch_id + + return BatchFeature(outputs) + + +class MolmoProcessingInfo(BaseProcessingInfo): + + def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper: + processor = self.ctx.get_hf_processor(**kwargs) + return MolmoProcessorWrapper(processor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[MolmoProcessorWrapper], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + ncols, nrows = processor.get_patches_grid_size( + image_width=image_width, + image_height=image_height, + ) + pooling_size = processor.pooling_size + + base_image_input_size = processor.base_image_input_size + base_image_input_d = processor.image_patch_size + + crop_patches = base_image_input_size[0] // base_image_input_d + + per_row = ncols // pooling_size + 1 + joint = per_row * (nrows // pooling_size) + 2 + image_token_length = (crop_patches + pooling_size - 1) // pooling_size + resize = (image_token_length + 1) * 
image_token_length + 2 + + return resize + joint + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + ) + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + tilings = get_candidate_tilings(processor.max_crops) + base_h, base_w = processor.base_image_input_size + + largest_feature_size, largest_feature_pinpoint = 0, None + for wr, hr in tilings: + width, height = base_w * wr, base_h * hr + + feat_size = self.get_num_image_tokens( + image_width=width, + image_height=height, + processor=processor, + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + +class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + processor = self.info.get_hf_processor() + + # Apply the chat template to the tokens + tokens = processor.processor.get_tokens_input( # type: ignore + self.info.get_tokenizer().decode(prompt_tokens), + message_format=processor.message_format, + always_start_with_space=processor.always_start_with_space, + ) + + processed_data = self.info.ctx.call_hf_processor( + processor, # type: ignore + dict(tokens=tokens), + ) + prompt_ids, = processed_data.pop("input_ids").tolist() + + return prompt_ids + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_crops = hf_inputs.get("num_crops", torch.empty(0)) + num_images = len(num_crops) + + return dict( + images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), + image_masks=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), + feat_is_patch=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), + embed_is_patch=MultiModalFieldConfig.batched("image"), + num_crops=MultiModalFieldConfig.batched("image"), + img_patch_id=MultiModalFieldConfig.shared("image", num_images), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_token_length_w = processor.image_token_length_w + image_token_length_h = processor.image_token_length_h + pooling_size = processor.pooling_size + + img_patch_id = processor.image_patch_id + img_col_id = processor.im_col_id + img_start_id = processor.im_start_id + img_end_id = processor.im_end_id + + extra_row = [img_patch_id] * image_token_length_w + [img_col_id] + extra_joint = 
([img_start_id] + extra_row * image_token_length_h + + [img_end_id]) + + def get_insertion_molmo(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = processor.get_patches_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + + joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) + + [img_col_id]) + joint = ([img_start_id] + joint_row * + ((nrows + 1) // pooling_size) + [img_end_id]) + + image_tokens = extra_joint + joint + return image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.prefix("<|endoftext|>"), + insertion=get_insertion_molmo, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor, + info=MolmoProcessingInfo, + dummy_inputs=MolmoDummyInputsBuilder) +class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, + SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + # vision backbone mapping + "image_projector.w1.": "image_projector.gate_proj.", + "image_projector.w3.": "image_projector.up_proj.", + "image_projector.w2.": "image_projector.down_proj.", + # language backbone mapping + "att_proj": "self_attn.qkv_proj", + "attn_out": "self_attn.o_proj", + "q_norm": "self_attn.q_norm", + "k_norm": "self_attn.k_norm", + "ff_proj": "mlp.gate_up_proj", + "ff_out": "mlp.down_proj", + "attn_norm": "input_layernorm", + "ff_norm": "post_attention_layernorm", + }, + orig_to_new_prefix={ + # vision backbone mapping + "model.vision_backbone.": "vision_backbone.", + # language backbone mapping + "model.transformer.blocks.": "model.layers.", + "model.transformer.ln_f.": "model.norm.", + # lm_head is renamed to model.transformer.mlp.down_proj firstly, + # we need to run a second renaming for it + "model.transformer.mlp.down_proj.": "lm_head.", + }, + ) + + packed_modules_mapping = { + "qkv_proj": ["qkv_proj"], + "gate_up_proj": ["gate_up_proj"], # language model + "merged_linear": ["gate_proj", "up_proj"] # image_projector + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + lora_config = vllm_config.lora_config + self.config = config + self.multimodal_config = multimodal_config + self.lora_config = lora_config + + vision_config = VisionBackboneConfig() + self.vision_backbone = MolmoVisionBackbone(config, vision_config, + quant_config) + self.model = MolmoModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.img_patch_id = None + + if self.config.weight_tying: + self.lm_head = self.model.transformer.wte + else: + self.lm_head = ParallelLMHead( + config.embedding_size or config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + + self.logits_processor = LogitsProcessor(config.embedding_size + or config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Optional[MolmoImageInputs]: + images = kwargs.pop("images", None) + if images is None: + return None + + if not isinstance(images, (torch.Tensor, list)): + raise ValueError("Incorrect type of images. 
" + f"Got type: {type(images)}") + + image_masks = kwargs.pop("image_masks", None) + if not (image_masks is None or isinstance(image_masks, + (torch.Tensor, list))): + raise ValueError("Incorrect type of image_masks. " + f"Got type: {type(image_masks)}") + + feat_is_patch = kwargs.pop("feat_is_patch", None) + if not isinstance(feat_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of feat_is_patch. " + f"Got type: {type(feat_is_patch)}") + + embed_is_patch = kwargs.pop("embed_is_patch", None) + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + num_crops = kwargs.pop("num_crops", None) + if not isinstance(num_crops, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_crops. " + f"Got type: {type(num_crops)}") + + img_patch_id = kwargs.pop("img_patch_id", None) + if not isinstance(img_patch_id, torch.Tensor): + raise ValueError("Incorrect type of img_patch_id. " + f"Got type: {type(img_patch_id)}") + self.img_patch_id = img_patch_id.flatten().unique().item() + + embed_is_patch = flatten_bn(embed_is_patch) + num_crops = flatten_bn(num_crops, concat=True) + + return MolmoImageInputs( + images=images, + image_masks=image_masks, + feat_is_patch=feat_is_patch, + embed_is_patch=embed_is_patch, + num_crops=num_crops, + ) + + def _process_image_input( + self, + image_input: MolmoImageInputs, + ) -> list[torch.Tensor]: + images = image_input["images"] + image_masks = image_input["image_masks"] + feat_is_patch = image_input["feat_is_patch"] + num_crops = image_input["num_crops"] + + # Call the vision backbone on the whole batch at once + images_flat = flatten_bn(images, concat=True) + image_masks_flat = (None if image_masks is None else flatten_bn( + image_masks, concat=True)) + feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True) + + image_features_flat = self.vision_backbone( + images=images_flat.unsqueeze(0), + image_masks=(None if image_masks_flat is None else + image_masks_flat.unsqueeze(0)), + ).squeeze(0) + + # Only the features corresponding to patch tokens are relevant + return [ + feats[f_is_patch] for feats, f_is_patch in zip( + image_features_flat.split(num_crops.tolist()), + feat_is_patch_flat.split(num_crops.tolist()), + ) + ] + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert self.img_patch_id is not None + + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.img_patch_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> SamplerOutput: + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + loader = AutoWeightsLoader(self) + weights = _get_weights_with_merged_embedding(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model", + connector="vision_backbone.image_projector", + tower_model="vision_backbone", + ) + + +def _get_weights_with_merged_embedding( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Iterable[Tuple[str, torch.Tensor]]: + embedding_weights = {} + for name, weight in weights: + if "wte.embedding" in name: + embedding_weights["embedding"] = weight + elif "wte.new_embedding" in name: + embedding_weights["new_embedding"] = weight + else: + yield (name, weight) + # this is compatible with most of quantization, + # because they won't quantize embed_tokens + embedding_weights = torch.cat( + [embedding_weights["embedding"], embedding_weights["new_embedding"]], + dim=0, + ) + yield ("model.embed_tokens.weight", embedding_weights) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py new file mode 100644 index 0000000..b30f3ee --- /dev/null +++ b/vllm/model_executor/models/mpt.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main +import math +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.mpt import MPTConfig + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +def _get_alibi_slopes( + total_num_heads: int, + alibi_bias_max: int, +) -> 
torch.Tensor: + next_power_of_2 = 2**math.ceil(math.log2(total_num_heads)) + m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32) + m = m.mul(alibi_bias_max / next_power_of_2) + slopes = 1.0 / torch.pow(2, m) + if next_power_of_2 != total_num_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads] + return slopes + + +class MPTAttention(nn.Module): + + def __init__( + self, + config: MPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads + self.clip_qkv = config.attn_config["clip_qkv"] + self.qk_ln = config.attn_config["qk_ln"] + self.alibi_bias_max = config.attn_config["alibi_bias_max"] + if "kv_n_heads" in config.attn_config: + self.total_num_kv_heads = config.attn_config['kv_n_heads'] + else: + self.total_num_kv_heads = self.total_num_heads + assert not config.attn_config["prefix_lm"] + assert config.attn_config["alibi"] + + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=not config.no_bias, + quant_config=quant_config, + ) + if self.qk_ln: + self.q_ln = nn.LayerNorm(self.d_model) + self.k_ln = nn.LayerNorm(self.d_model) + self.out_proj = RowParallelLinear( + self.d_model, + self.d_model, + bias=not config.no_bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Create the alibi slopes and slice them. + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads, + self.alibi_bias_max) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + self.head_dim = self.d_model // self.total_num_heads + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + del position_ids # unused. 
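+        # MPT injects position information through ALiBi biases inside the
+        # attention kernel, so explicit position ids are unnecessary.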
+ qkv, _ = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.qk_ln: + q = self.q_ln(q) + k = self.k_ln(k) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class MPTMLP(nn.Module): + + def __init__( + self, + config: MPTConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.d_model + expansion_ratio = config.expansion_ratio + intermediate_size = expansion_ratio * hidden_size + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=not config.no_bias, + quant_config=quant_config, + ) + self.act = get_act_fn("gelu") + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=not config.no_bias, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.up_proj(x) + x = self.act(x) + x, _ = self.down_proj(x) + return x + + +class MPTBlock(nn.Module): + + def __init__( + self, + config: MPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.d_model + self.norm_1 = nn.LayerNorm(hidden_size) + self.attn = MPTAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.norm_2 = nn.LayerNorm(hidden_size) + self.ffn = MPTMLP(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + x = self.norm_1(hidden_states) + x = self.attn( + position_ids=position_ids, + hidden_states=x, + ) + hidden_states = hidden_states + x + x = self.norm_2(hidden_states) + x = self.ffn(x) + hidden_states = hidden_states + x + return hidden_states + + +@support_torch_compile +class MPTModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + assert config.embedding_fraction == 1.0 + assert config.norm_type == "low_precision_layernorm" + + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: MPTBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.blocks") + self.norm_f = nn.LayerNorm(config.d_model) + if config.no_bias: + for module in self.modules(): + if hasattr(module, "bias") and isinstance( + module.bias, nn.Parameter): + # Remove the bias term in Linear and LayerNorm. 
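+                    # Registering None keeps the .bias attribute (returning
+                    # None) while dropping the parameter itself.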
+ module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for block in self.blocks[self.start_layer:self.end_layer]: + hidden_states = block(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class MPTForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + assert config.tie_word_embeddings + self.quant_config = quant_config + + self.transformer = MPTModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "transformer")) + self.lm_head = self.transformer.wte + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py new file mode 100644 index 0000000..0ea296b --- /dev/null +++ 
b/vllm/model_executor/models/nemotron.py @@ -0,0 +1,517 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Nemotron model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import NemotronConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +# The architecture is pretty similar to Llama, with these changes: +# - There is no gate_proj, just up_proj +# - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm +# - Squared ReLU instead of SwiGLU +# - Adds a partial_rotary_factor to RoPE + + +def _cast_if_autocast_enabled(*args): + if not torch.is_autocast_enabled(): + return args + else: + return torch.amp.autocast_mode._cast( + args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype()) + + +class NemotronLayerNorm1P(nn.LayerNorm): + + def __init__(self, + normalized_shape: Union[int, List[int], torch.Size], + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None): + super().__init__(normalized_shape, eps, elementwise_affine, bias, + device, dtype) + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> 
torch.Tensor: + if residual is not None: + x = x + residual + residual = x + args = _cast_if_autocast_enabled(x, self.normalized_shape, + self.weight + 1, self.bias, self.eps) + with torch.amp.autocast("cuda", enabled=False): + x = torch.nn.functional.layer_norm(*args) + return x if residual is None else (x, residual) + + +class NemotronMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.up_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_fn = get_act_fn(hidden_act) + + def forward(self, x): + up, _ = self.up_proj(x) + x = self.act_fn(up) + x, _ = self.down_proj(x) + return x + + +class NemotronAttention(nn.Module): + + def __init__( + self, + config: NemotronConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=self.partial_rotary_factor, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class NemotronDecoderLayer(nn.Module): + + def __init__( + self, + config: NemotronConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = NemotronAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = NemotronMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = NemotronLayerNorm1P( + config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: 
torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class NemotronModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: NemotronDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = NemotronLayerNorm1P(config.hidden_size, + eps=config.norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + assert isinstance(config, 
NemotronConfig) + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = NemotronModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py new file mode 100644 index 0000000..5c9b04c --- /dev/null +++ b/vllm/model_executor/models/nemotron_nas.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
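+# Illustrative sizing example (the numbers are assumptions for illustration,
+# not taken from any real DeciLM/NAS checkpoint): each block config below
+# derives its FFN width from `ffn_mult` via _ffn_mult_to_intermediate_size()
+# defined later in this file. For ffn_mult = 2.6 and hidden_size = 4096:
+#     int(2 * 2.6 * 4096 / 3) = 7099
+#     _find_multiple(7099, 256) = 7168
+# so that layer's LlamaMLP would be built with intermediate_size = 7168.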
+"""Inference-only deci model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Type, Union + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import HasNoOps, SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: + # DeciLM-specific code + intermediate_size = int(2 * ffn_mult * n_embd / 3) + return _find_multiple(intermediate_size, 256) + + +def _find_multiple(n: int, k: int) -> int: + # DeciLM-specific code + if n % k == 0: + return n + return n + k - (n % k) + + +class DeciLMDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + block_config = config.block_configs[layer_idx] + self._is_no_op_attention = block_config.attention.no_op + self._is_no_op_ffn = block_config.ffn.no_op + + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + bias_o_proj = attention_bias + # support internlm/internlm3-8b with qkv_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + + if not self._is_no_op_attention: + num_kv_heads = (config.num_attention_heads // + block_config.attention.n_heads_in_group) + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=num_kv_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if not self._is_no_op_ffn: + ffn_mult = block_config.ffn.ffn_mult + intermediate_size = _ffn_mult_to_intermediate_size( + ffn_mult, 
config.hidden_size) + + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + + if self._is_no_op_attention: + pass + else: + if (residual is None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + if not self._is_no_op_ffn: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class DeciModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer, + ): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + return layer_type( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + get_layer, + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + kv_cache_index = 0 + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + if not layer._is_no_op_attention: + hidden_states, residual = layer(positions, 
hidden_states, + residual) + kv_cache_index += 1 + else: + hidden_states, residual = layer(positions, hidden_states, + residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name)): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + # Mistral/Llama models can also be loaded with --load-format mistral + # from consolidated.safetensors checkpoints + mistral_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm", + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): + return DeciModel(vllm_config=vllm_config, prefix=prefix) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: 
torch.Tensor,
+               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
new file mode 100644
index 0000000..9d04f30
--- /dev/null
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -0,0 +1,241 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
+# --------------------------------------------------------
+# NVLM-D
+# Copyright (c) 2024 NVIDIA
+# Licensed under Apache 2.0 License [see LICENSE for details]
+# --------------------------------------------------------
+from collections.abc import Mapping, Sequence
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalKwargs
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   MultiModalDataItems)
+from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
+from vllm.multimodal.profiling import ProcessorInputs
+
+from .intern_vit import InternVisionModel
+from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
+                       InternVLChatModel, InternVLDummyInputsBuilder,
+                       InternVLMultiModalProcessor)
+
+IMG_PAD = "<|vision_pad|>"
+
+
+class NVLMProcessor(BaseInternVLProcessor):
+
+    @property
+    def image_token_id(self) -> int:
+        return self.tokenizer.get_vocab()[IMG_PAD]
+
+    def get_image_repl(
+        self,
+        feature_size: int,
+        num_patches: Optional[int],
+    ) -> PromptUpdateDetails[str]:
+        if num_patches is None:
+            raise NotImplementedError("Embedding inputs are not supported")
+
+        tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)]
+        if self.use_thumbnail:
+            tile_pos_identifiers += ["<tile_global_thumbnail>"]
+
+        context_size = feature_size // num_patches
+        features = "".join(identifier + IMG_PAD * context_size
+                           for identifier in tile_pos_identifiers)
+        repl = "<Image>" + features + "</Image>"
+
+        # We include the start and end as well because "<Image><tile" is
+        # tokenized into ["<", "Image", "><", "tile"], resulting in assertion error
+        # when trying to find "<Image>"
+        return PromptUpdateDetails(full=repl, features=repl)
+
+
+class NVLMProcessingInfo(BaseInternVLProcessingInfo):
+
+    def get_hf_processor(
+        self,
+        *,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        **kwargs: object,
+    ) -> NVLMProcessor:
+        if min_dynamic_patch is not None:
+            kwargs["min_dynamic_patch"] = min_dynamic_patch
+        if max_dynamic_patch is not None:
+            kwargs["max_dynamic_patch"] = max_dynamic_patch
+        if dynamic_image_size is not None:
+            kwargs["dynamic_image_size"] = dynamic_image_size
+
+        return self.ctx.init_processor(
+            NVLMProcessor,
+            config=self.get_hf_config(),
+            tokenizer=self.get_tokenizer(),
+            **kwargs,
+        )
+
+    def get_max_image_tokens(self) -> int:
+        hf_processor = self.get_hf_processor()
+        tokenizer = hf_processor.tokenizer
+
+        max_num_patches = hf_processor.max_dynamic_patch
+        # we need +1 here because max_dynamic_patch in config doesn't
+        # include the thumbnail patch
+        tile_pos_identifiers = [
+            f"<tile_{i + 1}>" for i in range(max_num_patches)
+        ]
+        if hf_processor.use_thumbnail and max_num_patches != 1:
+            tile_pos_identifiers += ["<tile_global_thumbnail>"]
+
+        # "<Image><tile" is tokenized into ["<", "Image", "><", "tile"]
+        # so we include <Image> in the start_str
+        start_str = "<Image>" + tile_pos_identifiers.pop(0)
+        end_str = "</Image>"
+        start_token_len = len(tokenizer.encode(start_str))
+        end_token_len = len(tokenizer.encode(end_str))
+        tile_token_len = sum(
+            len(tokenizer.encode(identifier))
+            for identifier in tile_pos_identifiers)
+        non_image_tokens_num = start_token_len + end_token_len + tile_token_len
+        return super().get_max_image_tokens() + non_image_tokens_num
+
+
+class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            # The newline is necessary to separate ">" of the current item
+            # and "<" of the next item
+            prompt_text="<image>\n" * num_images,
+            mm_data=mm_data,
+        )
+
+
+class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+        if "image_num_patches" in out_mm_kwargs:
+            image_num_patches = out_mm_kwargs["image_num_patches"]
+            assert isinstance(image_num_patches, torch.Tensor)
+            image_num_patches = image_num_patches.tolist()
+        elif "image_embeds" in out_mm_kwargs:
+            # TODO: Use image size information in dictionary embedding inputs
+            # to compute num_patches (similar to Qwen2-VL)
+            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+        else:
+            image_num_patches = []
+
+        def get_replacement_nvlm(item_idx: int):
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+            if isinstance(images, ImageEmbeddingItems):
+                feature_size = images.get_feature_size(item_idx)
+            else:
+                image_size = images.get_image_size(item_idx)
+                feature_size = self.info.get_num_image_tokens(
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                    processor=hf_processor,
+                )
+
+            num_patches = image_num_patches[item_idx]
+            if num_patches is not None:
+                assert isinstance(num_patches, int)
+
+            repl = hf_processor.get_image_repl(feature_size, num_patches)
+
+            return PromptUpdateDetails(
+                full=repl.full + "\n",
+                features=repl.features + "\n",
+            )
+
+        # See note in dummy data regarding why we have the extra newline
+        return [
+            PromptReplacement(
+                modality="image",
+                target="<image>\n",
+                replacement=get_replacement_nvlm,
+            )
+        ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor,
+                                        info=NVLMProcessingInfo,
+                                        dummy_inputs=NVLMDummyInputsBuilder)
+class NVLM_D_Model(InternVLChatModel):
+
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+        vit_hidden_size = config.vision_config.hidden_size
+        llm_intermediate_size = config.text_config.intermediate_size
+        llm_hidden_size = config.text_config.hidden_size
+
+        return nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
+            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+                      llm_intermediate_size,
+                      bias=False),
+            nn.GELU(),
+            nn.Linear(llm_intermediate_size,
llm_hidden_size, bias=False), + ) + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + # We added additional dummy heads to the original num of heads to + # make the number of heads divisible by 8. + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + num_dummy_heads=7, + prefix=prefix, + ) + else: + msg = "Monolith mode is not applicable to NVLM_D" + raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py new file mode 100644 index 0000000..4a341c9 --- /dev/null +++ b/vllm/model_executor/models/olmo.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
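+# Illustrative example (assumed values, not from a released OLMo config): with
+# hidden_size = 4096 and num_attention_heads = 32, OlmoAttention below uses
+# head_dim = 4096 // 32 = 128, and when clip_qkv = 8.0 the fused QKV output is
+# clamped in place to [-8.0, 8.0] before being chunked into q, k and v.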
+"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + # Attention output projection. 
+ self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") + + # MLP block. + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states) + hidden_states = hidden_states + residual + + # MLP block. 
+ residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@support_torch_compile +class OlmoModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + # Apply blocks one-by-one. + for layer in self.layers[self.start_layer:self.end_layer]: + # shape: (batch_size, seq_len, d_model) + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoForCausalLM(nn.Module, SupportsPP): + """ + Extremely barebones HF model wrapper. 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.model = OlmoModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py new file mode 100644 index 0000000..f9427cd --- /dev/null +++ b/vllm/model_executor/models/olmo2.py @@ -0,0 +1,422 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo2 model compatible with HuggingFace weights.""" + +from functools import partial +from typing import Iterable, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.distributed.utils import split_tensor_along_last_dim +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import ( + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.olmo2 import Olmo2Config + + +class Olmo2Attention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + assert isinstance(self.config, Olmo2Config) + + hidden_size = self.config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = self.config.num_attention_heads + + assert hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = (self.config.num_key_value_heads + or self.total_num_heads) + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, + eps=self.config.rms_norm_eps, + ) + self.q_norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=prefix, + ) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.o_proj", + ) + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Olmo2MLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))`` + (plus another skip connection). 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Olmo2DecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + # Attention block. + self.self_attn = Olmo2Attention(vllm_config=vllm_config, + prefix=f"{prefix}.self_attn") + + # MLP block. + self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") + + # LayerNorm + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Attention block. + residual = hidden_states + hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Olmo2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + assert isinstance(self.config, Olmo2Config) + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + self.start_layer, self.end_layer, self.layers = make_layers( + self.config.num_hidden_layers, + lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + self.config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + # Get embeddings of input. 
+ # shape: (batch_size, seq_len, d_model) + else: + hidden_states = self.embed_tokens(input_ids) + + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + assert isinstance(hidden_states, torch.Tensor) + + # Apply blocks one-by-one. + for layer in self.layers[self.start_layer:self.end_layer]: + # shape: (batch_size, seq_len, d_model) + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Olmo2ForCausalLM(nn.Module, SupportsPP): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + self.config = config + self.model = Olmo2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if is_pp_missing_parameter(name, self): + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. 
+ if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader # type: ignore + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py new file mode 100644 index 0000000..6cf3f1f --- /dev/null +++ b/vllm/model_executor/models/olmoe.py @@ -0,0 +1,455 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class OlmoeMoE(nn.Module): + """A tensor-parallel MoE implementation for Olmoe that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
+ """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + return final_hidden_states.view(orig_shape) + + +class OlmoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 4096, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.q_norm = RMSNorm(hidden_size, eps=1e-5) + self.k_norm = RMSNorm(hidden_size, eps=1e-5) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = OlmoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = OlmoeMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class OlmoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + 
super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoeDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class OlmoeForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = OlmoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + 
("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py new file mode 100644 index 0000000..d4c2b4c --- /dev/null +++ b/vllm/model_executor/models/opt.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py +# Copyright 2023 The vLLM team. +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OPT model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import OPTConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the + # embedding ids by 2 and adjust num_embeddings appropriately. 
Other + # models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, positions: torch.Tensor): + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.embed_dim = embed_dim + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + total_num_heads = num_heads + assert num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = embed_dim // total_num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + embed_dim, + self.head_dim, + total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class OPTDecoderLayer(nn.Module): + + def __init__( + self, + config: OPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + bias=config.enable_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.do_layer_norm_before = config.do_layer_norm_before + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, + elementwise_affine=config.layer_norm_elementwise_affine) + self.fc1 = ColumnParallelLinear( + self.embed_dim, + config.ffn_dim, + bias=config.enable_bias, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.activation_fn = get_act_fn(config.activation_function) + self.fc2 = RowParallelLinear( + config.ffn_dim, + self.embed_dim, + bias=config.enable_bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.final_layer_norm = nn.LayerNorm( + self.embed_dim, + elementwise_affine=config.layer_norm_elementwise_affine) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = 
self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class OPTDecoder(nn.Module): + + def __init__( + self, + config: OPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.word_embed_proj_dim, + ) + # Positional embeddings are replicated (not sharded). + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size) + + # Project out & in will be replicated if they exist. + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = ReplicatedLinear(config.hidden_size, + config.word_embed_proj_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.project_out") + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = ReplicatedLinear(config.word_embed_proj_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.project_in") + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to + # keep backward compatibility with checkpoints that have been fine-tuned + # before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, + elementwise_affine=config.layer_norm_elementwise_affine) + else: + self.final_layer_norm = None + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OPTDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + pos_embeds = self.embed_positions(positions) + if self.project_in is not None: + inputs_embeds, _ = self.project_in(inputs_embeds) + hidden_states = inputs_embeds + pos_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + if self.project_out is not None: + hidden_states, _ = self.project_out(hidden_states) + return hidden_states + + +@support_torch_compile +class OPTModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + 
quant_config = vllm_config.quant_config + + self.decoder = OPTDecoder(config, + cache_config, + quant_config, + prefix=f"{prefix}.decoder") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.decoder.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + return self.decoder(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + +class OPTForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = OPTModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if self.config.tie_word_embeddings: + self.lm_head = self.model.decoder.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.word_embed_proj_dim) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "lm_head.weight" in name and self.config.tie_word_embeddings: + continue + if name.startswith("decoder."): + name = "model." + name + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
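+                # (This for/else branch is reached only when no stacked-parameter
+                # mapping matched, so the weight is loaded through the param's own
+                # weight_loader, falling back to default_weight_loader.)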
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py new file mode 100644 index 0000000..0b42666 --- /dev/null +++ b/vllm/model_executor/models/orion.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py +# Copyright (c) OrionStar Inc. +# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE +"""Inference-only Orion-14B model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class OrionMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OrionAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class OrionDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = OrionAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = OrionMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + 
quant_config=quant_config, + ) + + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@support_torch_compile +class OrionModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OrionDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory([ + "hidden_states", + ], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OrionForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = OrionModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: 
Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py new file mode 100644 index 0000000..6fedb8c --- /dev/null +++ b/vllm/model_executor/models/paligemma.py @@ -0,0 +1,390 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import BatchFeature, PaliGemmaConfig + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptIndexTargets, + PromptInsertion, PromptUpdate, + PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import get_vision_encoder_info + +logger = init_logger(__name__) + + +class PaliGemmaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class PaliGemmaImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. 
+ """ + + +PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, + PaliGemmaImageEmbeddingInputs] + + +class PaliGemmaMultiModalProjector(nn.Module): + + def __init__(self, vision_hidden_size: int, projection_dim: int): + super().__init__() + + self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear(image_features) + return hidden_states + + +class PaliGemmaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(PaliGemmaConfig) + + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + vision_encoder_info = self.get_vision_encoder_info() + return vision_encoder_info.get_max_image_tokens() + + +class PaliGemmaDummyInputsBuilder( + BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + max_image_size = vision_config.image_size + + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class PaliGemmaMultiModalProcessor( + BaseMultiModalProcessor[PaliGemmaProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + if not mm_data: + prompt_ids = tokenizer.encode(prompt) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + tokenizer = self.info.get_tokenizer() + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + bos_token_id = tokenizer.bos_token_id + assert isinstance(bos_token_id, int) + + # Paligemma 1 and 2 have different tokenizer.add_bos_token + # Insert *n + after for Paligemma 1 + # Insert *n + for Paligemma 2 + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.prefix( + [bos_token_id] if tokenizer.add_bos_token else []), + insertion=PromptUpdateDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ), + ) + ] + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, + 
return_mm_hashes) + prompt_token_ids = mm_inputs["prompt_token_ids"] + + tokenizer = self.info.get_tokenizer() + newline_prompt = "\n" + newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108 + # Force to add newline at the end of prompt for paligemma's format + # This step can NOT be replacemented by current PromptUpdate methods + if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id: + prompt_token_ids.append(newline_token_id) + mm_inputs["prompt_token_ids"] = prompt_token_ids + mm_inputs["prompt"] += newline_prompt + + return mm_inputs + + +@MULTIMODAL_REGISTRY.register_processor( + PaliGemmaMultiModalProcessor, + info=PaliGemmaProcessingInfo, + dummy_inputs=PaliGemmaDummyInputsBuilder) +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + self.vision_tower = SiglipVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_tower")) + self.multi_modal_projector = PaliGemmaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + projection_dim=config.vision_config.projection_dim) + + self.quant_config = quant_config + + if config.text_config.model_type == "gemma": + config.text_config.architectures = ["GemmaForCausalLM"] + else: + config.text_config.architectures = ["Gemma2ForCausalLM"] + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + + return PaliGemmaImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + + image_embeds = flatten_bn(image_embeds, concat=True) + + return PaliGemmaImageEmbeddingInputs( + type="image_embeds", + data=image_embeds, + ) + + raise AssertionError("This line should be unreachable.") + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) + + return image_features + + def _process_image_input( + self, + image_input: PaliGemmaImageInputs, + ) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_tower is not None + pixel_values = image_input["data"] + image_features = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + + return self.multi_modal_projector(image_features) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa + vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py new file mode 100644 index 0000000..db8d170 --- /dev/null +++ b/vllm/model_executor/models/persimmon.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
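+# Port notes: unlike the RMSNorm-based Llama family, Persimmon uses standard
+# LayerNorm, optional per-head Q/K layernorm, and a partial rotary embedding
+# (only `partial_rotary_factor` of each head dimension is rotated).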
+"""Inference-only persimmon model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PersimmonConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class PersimmonMLP(nn.Module): + + def __init__(self, + config: PersimmonConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, + config.hidden_size, + quant_config=quant_config) + self.act = get_act_fn(config.hidden_act) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class PersimmonAttention(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + tensor_parallel_world_size = get_tensor_model_parallel_world_size() + + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tensor_parallel_world_size + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.is_causal = True + + assert (self.head_dim * self.total_num_heads) == self.hidden_size + assert self.total_num_heads % tensor_parallel_world_size == 0 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.is_qk_layernorm = config.qk_layernorm + + if self.is_qk_layernorm: + self.q_layernorm = nn.LayerNorm(self.head_dim) + self.k_layernorm = nn.LayerNorm(self.head_dim) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.partial_rotary_factor * self.head_dim), + max_position=self.max_position_embeddings, + base=self.rope_theta, + 
) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def _split_heads(self, x: torch.Tensor) -> torch.Tensor: + # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads, self.head_dim) + + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + # [seq_length, num_heads, head_dim] -> [seq_length, hidden_size] + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads * self.head_dim) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # [seq_length, 3 x hidden_size] + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + if self.is_qk_layernorm: + # [seq_length, num_heads, head_dim] + q = self._split_heads(q) + k = self._split_heads(k) + + q = self.q_layernorm(q) + k = self.k_layernorm(k) + + q = self._merge_heads(q) + k = self._merge_heads(k) + + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.dense(attn_output) + return output + + +class PersimmonDecoderLayer(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PersimmonAttention(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = PersimmonMLP(config, quant_config=quant_config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + outputs = hidden_states + return outputs + + +@support_torch_compile +class PersimmonModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PersimmonDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: 
Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class PersimmonForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.vocab_size = config.vocab_size + self.model = PersimmonModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=False) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # copy from vllm/model_executor/models/bloom.py + # NOTE: Persimmon's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. 
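+ # Example with hypothetical num_heads=4 and head_size=2: the checkpoint
+ # orders the fused rows per head as [q0 k0 v0 | q1 k1 v1 | ...], while
+ # QKVParallelLinear expects [q0..q3 | k0..k3 | v0..v3]; the
+ # view -> transpose -> reshape below performs exactly this regrouping
+ # along output_dim without changing any values.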
+ output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py new file mode 100644 index 0000000..6ee8021 --- /dev/null +++ b/vllm/model_executor/models/phi.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py +# Copyright 2023 The vLLM team. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# +# BSD 3-Clause License +# +# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
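+# Note: Phi uses a parallel block layout - in PhiLayer.forward the attention
+# and MLP branches both read the same LayerNorm output and their results are
+# summed together with the residual, rather than being applied sequentially.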
+"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PhiConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class PhiAttention(nn.Module): + + def __init__(self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + + # pylint: disable=C0103 + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_size, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + quant_config=quant_config, + ) + + scaling = self.head_size**-0.5 + rotary_dim = int(config.partial_rotary_factor * + (config.hidden_size // config.num_attention_heads)) + assert rotary_dim % 2 == 0 + + # pylint: disable=C0301 + # Refer to: + # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 + rope_theta = getattr(config, "rope_theta", 10000.0) + max_position_embeddings = getattr(config, "max_position_embeddings", + 2048) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.dense(attn_output) + return output + + +class PhiMLP(nn.Module): + + def __init__(self, + config: PhiConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + n_inner = getattr(config, "n_inner", None) + n_inner = n_inner if n_inner is 
not None else 4 * config.hidden_size + + self.fc1 = ColumnParallelLinear( + config.hidden_size, + n_inner, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + n_inner, + config.hidden_size, + quant_config=quant_config, + ) + self.act = get_act_fn(config.hidden_act) + + def forward(self, hidden_states): + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class PhiLayer(nn.Module): + + def __init__(self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.self_attn = PhiAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = PhiMLP(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_outputs = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + ) + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + return hidden_states + + +@support_torch_compile +class PhiModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + # lm_head use bias, cannot share word embeddings + 
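+ # The assert below rejects checkpoints configured with
+ # tie_word_embeddings=True, which this implementation does not support.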
assert not config.tie_word_embeddings + self.lora_config = lora_config + + self.quant_config = quant_config + + self.model = PhiModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v") + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + # pylint: disable=E1136 + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py new file mode 100644 index 0000000..8f84e07 --- /dev/null +++ b/vllm/model_executor/models/phi3.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from llama.py +"""Inference-only Phi3 model code inherit from Llama.py""" + +from vllm.model_executor.models.llama import LlamaForCausalLM + + +class Phi3ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py new file mode 100644 index 0000000..e2706f8 --- /dev/null +++ b/vllm/model_executor/models/phi3_small.py @@ -0,0 +1,468 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +def load_column_parallel_weight(param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + tp = get_tensor_model_parallel_world_size() + rk = get_tensor_model_parallel_rank() + assert param.size(0) * tp == loaded_weight.size(0) + s = rk * param.size(0) + e = (rk + 1) * param.size(0) + loaded_weight = loaded_weight[s:e] + assert param.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + + +class HeadMajorQKVParallelLinear(QKVParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +class HeadMajorColumnParallelLinear(MergedColumnParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def quick_gelu(x): + return x * torch.sigmoid(1.702 * x) + + +# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def gegelu(input, limit: 
Optional[float] = None): + a_gelu, a_linear = input[..., ::2], input[..., 1::2] + if limit is not None: + a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, + a_gelu.clamp(min=None, max=limit)) + a_linear = torch.where( + torch.isinf(a_linear), + a_linear, + a_linear.clamp(min=-limit, max=limit), + ) + out_gelu = quick_gelu(a_gelu) + return out_gelu * (a_linear + 1) + + +class Phi3SmallMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + assert (self.config.hidden_act == "gegelu" + ), "Only `gegelu` is supported for the 4.7 series of models .." + self.hidden_size = config.hidden_size + self.gegelu_limit = config.gegelu_limit + self.intermediate_size = config.intermediate_size + + self.up_proj = HeadMajorColumnParallelLinear( + self.hidden_size, + 2 * [self.intermediate_size], + bias=True, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + gate_up, _ = self.up_proj(x) + x = gegelu(gate_up) + x, _ = self.down_proj(x) + return x + + +class Phi3SmallSelfAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.sparse_block_size = config.blocksparse_block_size + self.homo_heads = config.blocksparse_homo_head_pattern + self.local_blocks = config.blocksparse_num_local_blocks + self.vert_stride = config.blocksparse_vert_stride + + assert (config.blocksparse_block_size == + config.blocksparse_triton_kernel_block_size) + + self.hidden_size = config.hidden_size + # Number of Query Heads + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // self.num_heads + self.tp_size = get_tensor_model_parallel_world_size() + # Number of total Key Value Heads before tensor parallel + self.num_key_value_heads = config.num_key_value_heads + self.num_q_per_kv = self.num_heads // self.num_key_value_heads + if self.tp_size > 1: + assert self.num_key_value_heads % self.tp_size == 0 + self.num_kv_heads_per_partion = max( + 1, self.num_key_value_heads // self.tp_size) + self.num_heads_per_partition = self.num_heads // self.tp_size + + self.max_position_embeddings = config.max_position_embeddings + self.rope_embedding_base = config.rope_embedding_base + self.rope_position_scale = config.rope_position_scale + self.is_causal = True + + norm_factor = None + if config.mup_use_scaling: + norm_factor = self.head_dim / config.mup_attn_multiplier + else: + norm_factor = math.sqrt(self.head_dim) + self.scale = 1 / norm_factor + + self.query_key_value = HeadMajorQKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=True, + quant_config=quant_config, + ) + + self.dense = RowParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config) + + if getattr(self.config, "rope_scaling", None) is not None: + rope_scaling = self.config.rope_scaling + for key in rope_scaling: + if isinstance(rope_scaling[key], list): + rope_scaling[key] = tuple(rope_scaling[key]) + + if "factor" not in rope_scaling: + rope_scaling["factor"] = self.rope_position_scale + else: + rope_scaling = { + "rope_type": "linear", + "factor": 
self.rope_position_scale, + } + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_embedding_base, + rope_scaling=rope_scaling, + ) + + # blocksparse params + self.blocksparse_block_size = config.blocksparse_block_size + self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks + self.blocksparse_vert_stride = config.blocksparse_vert_stride + + use_dense_attn = (getattr(self.config, + "dense_attention_every_n_layers", None) + and (self.layer_idx + 1) % + self.config.dense_attention_every_n_layers == 0) + + bs_params = None + if not use_dense_attn: + bs_params = { + 'max_seqlen': self.max_position_embeddings, + 'num_heads': self.num_heads_per_partition, + "num_kv_heads": self.num_kv_heads_per_partion, + "block_size": self.sparse_block_size, + "local_blocks": self.local_blocks, + "vert_stride": self.vert_stride, + "homo_head": self.homo_heads + } + + self.attn = Attention(self.num_heads_per_partition, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads_per_partion, + cache_config=cache_config, + quant_config=quant_config, + blocksparse_params=bs_params, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + qkv, _ = self.query_key_value(hidden_states) + + qkv = qkv.view(qkv.shape[:-1] + + (-1, (self.num_q_per_kv + 2), self.head_dim)) + q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2) + + # NOTE: this is required by RotaryEmbed, which indeed does not have to + # TODO: allow 3D QK for rotary forward + q = q.reshape(-1, self.head_dim * self.num_heads_per_partition) + k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.dense(attn_output) + + return output + + +class Phi3SmallDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3SmallSelfAttention(config, + layer_idx, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = Phi3SmallMLP(config, quant_config) + + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Phi3SmallModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.embed_tokens = 
VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.mup_embedding_multiplier = config.mup_embedding_multiplier + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Phi3SmallDecoderLayer(config, + int(prefix.split('.')[-1]), + cache_config, + quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + if (self.mup_embedding_multiplier is not None + and self.mup_embedding_multiplier > 0.0): + hidden_states = hidden_states * self.mup_embedding_multiplier + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Phi3SmallForCausalLM(nn.Module, SupportsPP): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Phi3SmallModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.vocab_size = config.vocab_size + self.mup_width_multiplier = config.mup_width_multiplier + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # tokens in tiktoken but not used + if hasattr(config, 'dummy_token_indices'): + device = self.lm_head.weight.device + self.register_buffer('dummy_token_indices', + torch.LongTensor( + config.dummy_token_indices).to(device), + persistent=False) + else: + self.dummy_token_indices = None + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, 
hidden_states, + sampling_metadata) + if self.dummy_token_indices is not None and logits is not None: + logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) + return logits + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + output_hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + output_hidden_states = output_hidden_states + return output_hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits / self.mup_width_multiplier, + sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if "lm_head.weight" in name and self.config.tie_word_embeddings: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py new file mode 100644 index 0000000..d5c6498 --- /dev/null +++ b/vllm/model_executor/models/phi3v.py @@ -0,0 +1,754 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The vLLM team. +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
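+# Phi-3-Vision pairs a CLIP ViT-L/14 (336 px) image encoder with an HD
+# multi-crop transform (Phi3HDImageEmbedding); the projected image features
+# are merged into a Llama-architecture language model at the positions of the
+# <|image_N|> placeholder tokens.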
+import re +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, + ProcessorMixin) + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, BoundPromptUpdate, + PlaceholderFeaturesInfo, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) +# yapf: enable +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .clip import CLIPVisionModel +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + +# Cannot find the following 2 numbers from hf config. +_IMAGE_TOKEN_ID = 32044 + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + intermediate_size=4096, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768) + + +def _init_img_processor(hf_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "") -> CLIPVisionModel: + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + layer_idx = hf_config.img_processor.get('layer_idx', -2) + + # Initialize the CLIP only up to the required feature layer + if layer_idx < 0: + num_hidden_layers = clip_config.num_hidden_layers + \ + layer_idx + 1 + else: + num_hidden_layers = layer_idx + 1 + + img_processor = CLIPVisionModel( + clip_config, + quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + + return img_processor + + +class Phi3VImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + image_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class Phi3VImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. 
+ """ + + +Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs] + + +class Phi3ImageEmbeddingBase(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.layer_idx: int + self.type_feature: str + self.img_processor: CLIPVisionModel + + def get_img_features(self, + img_embeds: torch.FloatTensor) -> torch.FloatTensor: + TYPE_FEATURE = self.type_feature + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the img_processor + img_feature = self.img_processor(img_embeds) + + if TYPE_FEATURE == "patch": + patch_feature = img_feature[:, 1:] + return patch_feature + + if TYPE_FEATURE == "cls_patch": + return img_feature + + raise NotImplementedError + + +# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py +class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): + """Phi3 Image embedding with HD transform.""" + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "") -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + + self.img_processor = _init_img_processor( + config, quant_config, prefix=f"{prefix}.img_processor") + + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + + self.image_dim_out = image_dim_out + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = config.embd_layer.get('use_hd_transform', + False) + self.with_learnable_separator = config.embd_layer.get( + 'with_learnable_separator', False) + self.hd_transform_order = config.embd_layer.get( + 'hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform and self.with_learnable_separator + + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter( + torch.empty([1, 1, 1, self.image_dim_out * 4])) + + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * 4, dim_projection)] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + + self.type_feature = config.img_processor.get('type_feature', 'patch') + + def forward(self, pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor) -> torch.FloatTensor: + """ + process image and return vision embeddings. 
+ + pixel_values: (num_images, num_crops, c, h, w) + output: (num_images, num_img_tokens, hidden_size) + """ + num_images, num_crops, c, h, w = pixel_values.shape + pixel_values = pixel_values.flatten(0, 1) + img_features = self.get_img_features(pixel_values) + img_features = img_features.reshape(num_images, num_crops, -1, + self.image_dim_out) + image_features_proj = self.hd_feature_transform( + img_features, image_sizes) + return image_features_proj + + def hd_feature_transform(self, image_features, image_sizes): + """ + image_features: (num_images, num_crops+1, 24*24, 1024) + """ + assert ( + self.hd_transform_order == 'sub_glb' + ), f'hd_transform_order `{self.hd_transform_order}` not implemented' + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + global_image_features = image_features[:, + 0] # (num_images, 24*24, 1024) + # global feature can be viewed as a special HD case with num_crops 1x1 + global_image_features_hd = self.reshape_hd_patches_2x2merge( + global_image_features, 1, 1) + global_image_features_hd_newline = self.add_image_newline( + global_image_features_hd) + + batch_image_features_proj = [] + # need a for loop to process each image because of different image sizes + # (patch arrangement is different for each image) + for i, img_size in enumerate(image_sizes): + h, w = img_size + h_crop = h // 336 + w_crop = w // 336 + num_crops = h_crop * w_crop + + # NOTE: real num_crops is padded + # (num_crops, 24*24, 1024) + sub_image_features = image_features[i, 1:1 + num_crops] + sub_image_features_hd = self.reshape_hd_patches_2x2merge( + sub_image_features, h_crop, w_crop) + sub_image_features_hd_newline = self.add_image_newline( + sub_image_features_hd) + + # [sub features, separator, global features] + image_embeddings = torch.cat([ + sub_image_features_hd_newline.squeeze( + 0), # (h_crop*12*(w_crop*12+1), 4096) + self.glb_GN.squeeze(0), + global_image_features_hd_newline[i], + ]) + img_proj = self.img_projection( + image_embeddings.to(target_device, target_dtype)) + batch_image_features_proj.append(img_proj) + + return batch_image_features_proj + + def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096) + where h_crop*w_crop == num_crops + """ + N, L, C = image_features.shape + assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = N // (h_crop * w_crop) + H = int(L**0.5) + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape(num_images, h_crop, w_crop, H // 2, H // 2, + -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape(num_images, h_crop * H // 2, w_crop * H // 2, + 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + ) + return image_features_hd + + def add_image_newline(self, image_features_hd): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + 
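+ # sub_GN is the learned (1, 1, 1, 4*C) separator declared in __init__;
+ # expanding it to (num_images, h, 1, hid_dim) appends one separator column
+ # per row of patches, acting as a "newline" once the rows are flattened.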
newline_embeddings = self.sub_GN.expand(num_images, h, -1, + -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat( + [image_features_hd, newline_embeddings], + dim=2).reshape(num_images, -1, hid_dim) + return image_features_hd_newline + + +class Phi3VProcessingInfo(BaseProcessingInfo): + + def get_hf_processor( + self, + *, + num_crops: Optional[int] = None, + **kwargs: object, + ) -> ProcessorMixin: + if num_crops is not None: + kwargs["num_crops"] = num_crops + + return self.ctx.get_hf_processor(**kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + target_width, target_height = self.get_image_size_with_most_features() + + max_image_tokens = self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + ) + + return {"image": max_image_tokens} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.calc_num_image_tokens_from_image_size( # type: ignore + width=image_width, + height=image_height, + ) + + def get_image_size_with_most_features(self) -> ImageSize: + # Result in the max possible feature size (h:w = 16:1) + return ImageSize(height=8000, width=50) + + +class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + + return ProcessorInputs( + prompt_text="".join(image_tokens[:num_images]), + mm_data=mm_data, + ) + + +class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + input_ids = processed_outputs["input_ids"] + assert isinstance(input_ids, torch.Tensor) + + # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, + # which will cause OverflowError when decoding the prompt_ids. 
+ # Therefore, we need to do an early replacement here + input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + + def get_replacement_phi3v(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens + + return PromptUpdateDetails( + full=image_tokens, + features=image_tokens, + ) + + num_images = mm_items.get_count("image", strict=False) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_phi3v, + ) for image_token in image_tokens[:num_images] + ] + + def _apply_prompt_updates( + self, + token_ids: list[int], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + # align to hf behavior when there are images + if len(mm_item_counts): + tokenizer = self.info.get_tokenizer() + # to decode token_ids to the original text, we need to + # 1. remove the first bos token + # 2. 
remove space after each special token + # introduced by the tokenizer + if len(token_ids) and token_ids[0] == tokenizer.bos_token_id: + token_ids = token_ids[1:] + text = tokenizer.decode(token_ids) + for special_tokens in tokenizer.special_tokens_map.values(): + if isinstance(special_tokens, str): + text = text.replace(f"{special_tokens} ", special_tokens) + elif isinstance(special_tokens, list): + for special_token in special_tokens: + text = text.replace(f"{special_token} ", special_token) + # perform hf behavior + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407 + pattern = r"<\|image_\d+\|>" + prompt_chunks = [ + tokenizer(chunk).input_ids + for chunk in re.split(pattern, text) + ] + image_tags = [ + tokenizer(chunk, add_special_tokens=False).input_ids + for chunk in re.findall(pattern, text) + ] + if len(prompt_chunks) > len(image_tags): + image_tags.append([]) + token_ids = [ + e for sublist in zip(prompt_chunks, image_tags) + for ele in sublist for e in ele + ] + + token_ids, text, placeholders = super()._apply_prompt_updates( + token_ids=token_ids, + mm_prompt_updates=mm_prompt_updates, + mm_item_counts=mm_item_counts, + ) + + # Keep the behavior in line with HF processor + if text.startswith(" <|image|>"): + text = text.replace(" <|image|>", "<|image|>", 1) + token_ids = [token_ids[0], *token_ids[2:]] + placeholders = { + modality: [ + PlaceholderFeaturesInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=p.start_idx - 1, + tokens=p.tokens, + ) for p in ps + ] + for modality, ps in placeholders.items() + } + + return token_ids, text, placeholders + + +@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, + info=Phi3VProcessingInfo, + dummy_inputs=Phi3VDummyInputsBuilder) +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, + SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.vision_embed_tokens.wte": "embed_tokens", + "model.vision_embed_tokens.": "vision_embed_tokens.", + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + self.image_token_id = _IMAGE_TOKEN_ID + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "model.embed_tokens"), + ) + + # TODO: Optionally initializes this for supporting input embeddings. 
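+ # vision_embed_tokens maps pixel crops to language-model-sized features;
+ # get_input_embeddings later writes them over the image placeholder
+ # positions via merge_multimodal_embeddings.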
+ self.vision_embed_tokens = Phi3HDImageEmbedding( + config, + self.quant_config, + prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + # The prefix is empty intentionally because default prefix of + # LlamaForCausalLM is "model" + prefix="", + # We don't directly initialize vLLM's LlamaForCausalLM so we + # can automatically apply embedding wrapper if this model is + # initialized as an embedding model + architectures=["LlamaForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Phi3VImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + + return Phi3VImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True))) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return Phi3VImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: Phi3VImageInputs, + ) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + image_data = image_input["data"] + if is_list_of(image_data, torch.Tensor): + # it's already a list of tensors + return image_data + if len(image_data.shape) == 3: + # 3D tensor + return list(torch.unbind(image_data, dim=0)) + raise ValueError( + "We expect batched 2D tensors; " + "this can be either a list of 2D tensors or a single 3D tensor." 
+ ) + + assert self.vision_embed_tokens is not None + image_embeds = self.vision_embed_tokens(image_input["data"], + image_input["image_sizes"]) + + return image_embeds + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.image_token_id) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object): + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + + # The HF config doesn't specify whether these are tied, + # so we detect it this way + if "embed_tokens.weight" not in autoloaded_weights: + self.embed_tokens = self.language_model.model.embed_tokens + autoloaded_weights.add("embed_tokens.weight") + return autoloaded_weights diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py new file mode 100644 index 0000000..cb75ee1 --- /dev/null +++ b/vllm/model_executor/models/phi4mm.py @@ -0,0 +1,1804 @@ +# SPDX-License-Identifier: Apache-2.0 +import math +import re +from functools import lru_cache +from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import numpy as np +import scipy.signal +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from transformers import PretrainedConfig, SiglipVisionConfig +from transformers.utils import logging + +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext) +from vllm.inputs.data import TokenInputs, token_inputs +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from 
vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config + +from .idefics2_vision_model import Idefics2VisionTransformer +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only +from .phi4mm_audio import AudioEmbedding +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix + +# <|endoftext10|> (see vocab.json in hf model) +_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 +# <|endoftext11|> +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 + +_AUDIO_MAX_SOUNDFILE_SIZE = 241_000 +DUMMY_SAMPLING_FREQUENCY = 16_000 # kHz + +DYNAMIC_HD = 16 +AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>" +IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>" + +SIGLIP_NAME = "siglip-so400m-patch14-448" +VISION_ENCODER_TO_PROCESSING_CONFIG = { + 'siglip-so400m-patch14-448': { + 'dynamic_hd': 16, + 'vit_image_size': 448, + 'vit_patch_size': 14, + 'token_compression_factor': 2, + }, +} +logger = logging.get_logger(__name__) +# This is a workaround to prevent text (user input) + audio + image +# from being used in the same prompt. +# It includes token ids for "/n" and tokens in added_tokens_decoder +# from the tokenizer_confg.json file. +NON_USER_INPUT_TOKENS = { + 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, + 200023, 200024, 200025, 200026, 200027, 200028 +} + + +def get_max_dummy_image(ctx: InputContext): + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + + max_side = vit_image_size * dynamic_hd_size + dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side) + return dummy_image + + +# image token length +def get_max_phi4mm_image_tokens(ctx: InputContext): + dummy_image = get_max_dummy_image(ctx) + + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size, + vit_image_size, + vit_patch_size, + token_compression_factor) + return image_num_tokens + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, + image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + 
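+# Illustrative example for _find_target_aspect_ratio below (added, not part of
+# the upstream code): with image_size=448 and max_num=16, a 1000x700 image gives
+# w_crop_num = ceil(1000/448) = 3 and h_crop_num = ceil(700/448) = 2; since
+# 3 * 2 = 6 <= max_num, the crop grid is (3, 2), i.e. target_width = 1344 and
+# target_height = 896 (one extra global crop is added later in preprocess(),
+# matching the example shapes in Phi4MMImageEncoder.forward).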
+def _find_target_aspect_ratio(image, image_size, max_num, min_num): + orig_width, orig_height = image.size + + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + logger.debug("target_aspect_ratio: %s", target_aspect_ratio) + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + return target_aspect_ratio, target_height, target_width + + +def _get_padding_size(image, target_height, target_width): + orig_width, orig_height = image.size + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + + if ratio_width < ratio_height: + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + return padding_height, padding_width + + +def dynamic_preprocess(image, + min_num=1, + max_num=12, + image_size=384, + mask_size=27): + target_aspect_ratio, target_height, target_width =\ + _find_target_aspect_ratio( + image, image_size, max_num, min_num) + padding_height, padding_width = _get_padding_size(image, target_height, + target_width) + + # Calculate the ratio + orig_width, orig_height = image.size + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + else: + new_size = (int(orig_width * ratio_height), target_height) + + attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), + int(mask_size * target_aspect_ratio[0]))) + if padding_width >= 14: + attention_mask[:, -math.floor(padding_width / 14):] = 0 + if padding_height >= 14: + attention_mask[-math.floor(padding_height / 14):, :] = 0 + assert attention_mask.sum( + ) > 0, f'attention mask is empty {attention_mask}' + + if min(new_size[1], target_height) < 10 or min(new_size[0], + target_width) < 10: + raise ValueError(f'the aspect ratio is very extreme {new_size}') + + image = T.functional.resize( + image, + [new_size[1], new_size[0]], + ) + + resized_img = T.functional.pad(image, + [0, 0, padding_width, padding_height], + fill=[255, 255, 255]) + + return resized_img, attention_mask + + +def pad_to_max_num_crops(images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if max_crops > B: + pad = torch.zeros(max_crops - B, + 3, + H, + W, + dtype=images.dtype, + device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + +def pad_mask_to_max_num_crops(masks, max_crops=5): + B, H, W = masks.shape + if max_crops > B: + pad = torch.ones(max_crops - B, + H, + W, + dtype=masks.dtype, + device=masks.device) + masks = torch.cat([masks, pad], dim=0) + return masks + + +def preprocess(images, dynamic_hd_size, 
vit_resolution, vit_patch_size): + + # Basic settings. + img_processor = T.Compose([ + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) + # Dynamic HD + base_resolution = vit_resolution + images = [image.convert('RGB') for image in images] + # cover 384 and 448 resolution + mask_resolution = base_resolution // vit_patch_size + elems, image_attention_masks = [], [] + for im in images: + elem, attention_mask = dynamic_preprocess(im, + max_num=dynamic_hd_size, + image_size=base_resolution, + mask_size=mask_resolution) + elems.append(elem) + image_attention_masks.append(attention_mask) + hd_images = [img_processor(im) for im in elems] + global_image = [ + torch.nn.functional.interpolate( + im.unsqueeze(0).float(), + size=(base_resolution, base_resolution), + mode='bicubic', + ).to(im.dtype) for im in hd_images + ] + shapes = [[im.size(1), im.size(2)] for im in hd_images] + mask_shapes = [[mask.size(0), mask.size(1)] + for mask in image_attention_masks] + global_attention_mask = [ + torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images + ] + hd_images_reshape = [ + im.reshape(1, 3, h // base_resolution, base_resolution, + w // base_resolution, base_resolution).permute( + 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution, + base_resolution).contiguous() + for im, (h, w) in zip(hd_images, shapes) + ] + attention_masks_reshape = [ + mask.reshape(1, h // mask_resolution, mask_resolution, + w // mask_resolution, mask_resolution).permute( + 0, 1, 3, 2, 4).reshape(-1, mask_resolution, + mask_resolution).contiguous() + for mask, (h, w) in zip(image_attention_masks, mask_shapes) + ] + # NOTE token compression is hard coded here, and odd numbers seems to fail + downsample_attention_masks = [ + mask[:, 0::2, + 0::2].reshape(1, h // mask_resolution, w // mask_resolution, + mask_resolution // 2 + mask_resolution % 2, + mask_resolution // 2 + mask_resolution % 2).permute( + 0, 1, 3, 2, 4) + for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) + ] + downsample_attention_masks = [ + mask.reshape(mask.size(1) * mask.size(2), + mask.size(3) * mask.size(4)) + for mask in downsample_attention_masks + ] + # NOTE hard coded number of tokens + num_img_tokens = [ + 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 + for mask in downsample_attention_masks + ] + + hd_images_reshape = [ + torch.cat([_global_image] + [_im], dim=0) + for _global_image, _im in zip(global_image, hd_images_reshape) + ] + hd_masks_reshape = [ + torch.cat([_global_mask] + [_mask], + dim=0) for _global_mask, _mask in zip( + global_attention_mask, attention_masks_reshape) + ] + max_crops = max([img.size(0) for img in hd_images_reshape]) + image_transformed = [ + pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape + ] + image_transformed = torch.stack(image_transformed, dim=0) + mask_transformed = [ + pad_mask_to_max_num_crops(mask, max_crops) \ + for mask in hd_masks_reshape + ] + mask_transformed = torch.stack(mask_transformed, dim=0) + + returned_input_image_embeds = image_transformed + returned_image_sizes = torch.tensor(shapes, dtype=torch.long) + returned_image_attention_mask = mask_transformed + returned_num_img_tokens = num_img_tokens + + data = { + "pixel_values": returned_input_image_embeds, + "image_sizes": returned_image_sizes, + "image_attention_mask": returned_image_attention_mask, + "num_img_tokens": returned_num_img_tokens, + } + return data + + +def get_navit_vision_model(layer_idx: int = -1, **kwargs): + vision_config = { + "hidden_size": 1152, + 
"image_size": 448, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } + + model_config = SiglipVisionConfig(**vision_config, **kwargs) + if layer_idx < 0: + num_hidden_layers = model_config.num_hidden_layers \ + + layer_idx + 1 + else: + num_hidden_layers = layer_idx + 1 + + vision_model = Idefics2VisionTransformer( + config=model_config, + require_post_norm=False, + num_hidden_layers_override=num_hidden_layers, + ) + + return vision_model + + +class Phi4MMImageEncoder(nn.Module): + """Image embedding.""" + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + model_dir: str = "") -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + + # layer_idx to output the img features + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get( + 'type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx) + + pe_weight = self.img_processor.embeddings.position_embedding.weight + L, D = pe_weight.size() + H = int(math.sqrt(L)) + assert H**2 == L, f'position embedding size {L} is not square' + if H % 2 != 0: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + H += 1 + image_dim_out = D + # ((448/14)//2)**2 + self.num_img_tokens = (H // 2)**2 + self.base_feat_height_target = H + + self.image_dim_out = image_dim_out + self.img_sizes = None + self.image_attention_mask = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = True + self.with_learnable_separator = True + self.hd_transform_order = "sub_glb" + self.freeze_img_processor = False + self.crop_size = 448 + + # image token compression + self.image_token_compression_cls = 'avg_pool_2d' + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.base_feat_height_reduction = 1 + self.base_feat_height_target = self.base_feat_height_target // 2 + + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform == self.with_learnable_separator, \ + 'use_hd_transform and with_learnable_separator should have same value' + assert self.use_hd_transform, \ + 'learnable separator is only for hd transform' + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter( + torch.zeros([ + 1, 1, self.image_dim_out * self.base_feat_height_reduction**2 + ])) + self.sub_GN = nn.Parameter( + torch.zeros([ + 1, 1, 1, + self.image_dim_out * self.base_feat_height_reduction**2 + ])) + + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out * self.base_feat_height_reduction**2, + dim_projection) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + + self.vocab_size = config.vocab_size + self.img_features = None + + self.use_out_place_operations = False + + def get_img_features(self, + img_embeds: torch.FloatTensor, + attention_mask=None) -> torch.FloatTensor: + + img_feature = self.img_processor(img_embeds, + patch_attention_mask=attention_mask) + + if self.type_feature == "patch": + patch_feature = img_feature + + use_token_compression = self.image_token_compression is not None 
+ use_padding = getattr(self, 'img_processor_padding', + None) is not None + if use_token_compression or use_padding: + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, + patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + + if use_padding: + patch_feature = self.img_processor_padding(patch_feature) + if use_token_compression: + patch_feature = self.image_token_compression(patch_feature) + + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view( + -1, + patch_feature.size(1) * patch_feature.size(2), + patch_feature.size(-1)) + + return patch_feature + + raise NotImplementedError + + def forward(self, pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + image_attention_mask: torch.Tensor) -> torch.FloatTensor: + """ + process image and return vision embeddings. + + pixel_values: (num_images, num_crops, c, h, w) + image_sizes: [[h1, w1], [h2, w2]] + image_attention_mask: num_images x num_crops x 32 x 32 + output: (num_images, num_img_tokens, hidden_size) + """ + + # eg + # pixel_values: torch.Size([1, 7, 3, 448, 448]) + # image_sizes: tensor([[ 896, 1344]], device='cuda:0') + # output: torch.Size([1, 1841, 3072]) + + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + img_sizes = image_sizes + num_images, num_crops, c, h, w = pixel_values.shape + bs = num_images + pixel_values = pixel_values.flatten(0, 1) + + img_features = self.get_img_features( + pixel_values, + image_attention_mask.type(torch.BoolTensor).flatten( + 0, 1).to(target_device)) + + base_feat_height_target = self.base_feat_height_target + base_resolution = self.crop_size + base_feat_height_reduction = self.base_feat_height_reduction + + base_feat_height = base_feat_width = int(np.sqrt( + img_features.shape[1])) + assert base_feat_height == base_feat_height_target \ + and base_feat_width == base_feat_height_target, \ + f'base_feat_height: {base_feat_height},"\ + f" base_feat_width: {base_feat_width}, "\ + f"expect {base_feat_height_target} features for hd transform' + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view(bs, -1, + base_feat_height * base_feat_width, + self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + h = h // base_resolution + w = w // base_resolution + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature.reshape(1, H, H, C).reshape( + 1, H // base_feat_height_reduction, base_feat_height_reduction, + H // base_feat_height_reduction, base_feat_height_reduction, + C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( + 1, H // base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * + C).contiguous() + temp_glb_GN = self.sub_GN.repeat(1, + H // base_feat_height_reduction, + 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape( + 1, -1, + base_feat_height_reduction * 
base_feat_height_reduction * C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) -> + # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + sub_img = sub_img.reshape(B_, H, H, C).reshape( + B_, H // base_feat_height_reduction, + base_feat_height_reduction, H // base_feat_height_reduction, + base_feat_height_reduction, + C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( + B_, -1, base_feat_height_reduction * + base_feat_height_reduction * C).contiguous() + sub_img = sub_img.reshape( + 1, h, w, base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction, + -1).permute(0, 1, 3, 2, 4, 5).reshape( + 1, h * base_feat_height // base_feat_height_reduction, + w * base_feat_width // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * + C) + + if image_attention_mask is not None and len( + image_attention_mask) > 0: + reshaped_image_attention_mask = image_attention_mask[ + _bs, 1:B_ + 1, 0::2, 0::2].reshape( + 1, h, w, + base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction).permute( + 0, 1, 3, 2, 4).reshape( + 1, h * base_feat_height // + base_feat_height_reduction, w * + base_feat_width // base_feat_height_reduction) + useful_height = int( + reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int( + reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) + temp_len = int( + image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item( + )) + (useful_height + + 1) + base_feat_height // base_feat_height_reduction + else: + temp_sub_GN = self.sub_GN.repeat( + 1, h * base_feat_height // base_feat_height_reduction, 1, + 1) + temp_len = int((h * w + 1) * self.num_img_tokens + 1 + + (h + 1) * base_feat_height // + base_feat_height_reduction) + + sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape( + 1, -1, + base_feat_height_reduction * base_feat_height_reduction * C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError( + f'hd_transform_order = {self.hd_transform_order}, "\ + "not implemented') + + #temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + assert temp_len == output_imgs[-1].shape[ + 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ + "{output_imgs[-1].shape[1]}' + + output_len.append(temp_len) + + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection( + _output_img.to(target_device).to(target_dtype)) + img_set_tensor.append(img_feature_proj) + + return img_set_tensor + + +class Phi4MMAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Tuple[NestedTensors] + """Shape: `((batch_size, num_audios, 80, M), )""" + + +class Phi4MMAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + +Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] + + +def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): + """Create a Mel filter-bank the same as 
SpeechLib FbankFC. + + Args: + sample_rate (int): Sample rate in Hz. number > 0 [scalar] + n_fft (int): FFT size. int > 0 [scalar] + n_mel (int): Mel filter size. int > 0 [scalar] + fmin (float): lowest frequency (in Hz). If None use 0.0. + float >= 0 [scalar] + fmax: highest frequency (in Hz). If None use sample_rate / 2. + float >= 0 [scalar] + + Returns + out (numpy.ndarray): Mel transform matrix + [shape=(n_mels, 1 + n_fft/2)] + """ + + bank_width = int(n_fft // 2 + 1) + if fmax is None: + fmax = sample_rate / 2 + if fmin is None: + fmin = 0 + assert fmin >= 0, "fmin cannot be negative" + assert (fmin < fmax <= + sample_rate / 2), "fmax must be between (fmin, samplerate / 2]" + + def mel(f): + return 1127.0 * np.log(1.0 + f / 700.0) + + def bin2mel(fft_bin): + return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) + + def f2bin(f): + return int((f * n_fft / sample_rate) + 0.5) + + # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] + klo = f2bin(fmin) + 1 + khi = f2bin(fmax) + + khi = max(khi, klo) + + # Spec 2: SpeechLib uses triangles in Mel space + mlo = mel(fmin) + mhi = mel(fmax) + m_centers = np.linspace(mlo, mhi, n_mels + 2) + ms = (mhi - mlo) / (n_mels + 1) + + matrix = np.zeros((n_mels, bank_width), dtype=np.float32) + for m in range(0, n_mels): + left = m_centers[m] + center = m_centers[m + 1] + right = m_centers[m + 2] + for fft_bin in range(klo, khi): + mbin = bin2mel(fft_bin) + if left < mbin < right: + matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms + + return matrix + + +class LogFbankProcessor: + + def __init__(self): + + self._eightk_method = "fillzero" + self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T + + self._hamming400 = np.hamming(400) # for 16k audio + self._hamming200 = np.hamming(200) # for 8k audio + + def extract_spectrogram(self, wav, fs): + """Extract spectrogram features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. + """ + if wav.ndim > 1: + wav = np.squeeze(wav) + + # by default, we extract the mean if stereo + if len(wav.shape) == 2: + wav = wav.mean(1) + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 16000) + fs = 16000 + elif 8000 < fs < 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 8000) + fs = 8000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + if fs == 8000: + if self._eightk_method == "resample": + # Input audio is 8 kHz. Convert to 16 kHz before feature + # extraction + wav = scipy.signal.resample_poly(wav, 2, 1) + fs = 16000 + # Do nothing here for fillzero method + elif fs != 16000: + # Input audio is not a supported sample rate. 
+ raise RuntimeError( + f"Input data using an unsupported sample rate: {fs}") + + preemphasis = 0.97 + + if fs == 8000: + n_fft = 256 + win_length = 200 + hop_length = 80 + fft_window = self._hamming200 + elif fs == 16000: + n_fft = 512 + win_length = 400 + hop_length = 160 + fft_window = self._hamming400 + + # Spec 1: SpeechLib cut remaining sample insufficient for a hop + n_batch = (wav.shape[0] - win_length) // hop_length + 1 + # Here we don't use stride_tricks since the input array may not satisfy + # memory layout requirement and we need writeable output + # Here we only use list of views before copy to destination + # so it is more efficient than broadcasting + y_frames = np.array( + [ + wav[_stride:_stride + win_length] + for _stride in range(0, hop_length * n_batch, hop_length) + ], + dtype=np.float32, + ) + + # Spec 2: SpeechLib applies preemphasis within each batch + y_frames_prev = np.roll(y_frames, 1, axis=1) + y_frames_prev[:, 0] = y_frames_prev[:, 1] + y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + + S = np.fft.rfft(fft_window * y_frames, n=n_fft, + axis=1).astype(np.complex64) + + if fs == 8000: + # Need to pad the output to look like 16 kHz data but with zeros in + # the 4 to 8 kHz bins. + frames, bins = S.shape + padarray = np.zeros((frames, bins)) + S = np.concatenate((S[:, 0:-1], padarray), + axis=1) # Nyquist bin gets set to zero + + spec = np.abs(S).astype(np.float32) + return spec + + def extract_features(self, wav, fs): + """Extract log filterbank features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. 
+ """ + spec = self.extract_spectrogram(wav, fs) + spec_power = spec**2 + + fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) + log_fbank = np.log(fbank_power).astype(np.float32) + + return log_fbank + + +@lru_cache +def audio_feature_extractor() -> LogFbankProcessor: + # Creates an instance of the audio processor, needed to extract the + # the audio features from the sound file + # LRU cache ensures that we only make one copy + return LogFbankProcessor() + + +def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, + vit_patch_size, token_compression_factor): + """ + compute the number of tokens an image is expected to take up considering + the image encoder architecture and exclude output features containing + only padding pixels + + for siglip, vit_image_size=448, vit_patch_size=14, so output will be + 32x32 feature map + NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 + """ + assert vit_image_size % vit_patch_size == 0, \ + "vit_image_size must be divisible by vit_patch_size" + assert vit_image_size // vit_patch_size % token_compression_factor == 0, \ + "vit_image_size // vit_patch_size must be divisible by "\ + "token_compression_factor" + + target_aspect_ratio, target_height, target_width = ( + _find_target_aspect_ratio(image, + vit_image_size, + dynamic_hd_size, + min_num=1)) + assert target_aspect_ratio[ + 0] * vit_image_size == target_width, \ + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + assert target_aspect_ratio[ + 1] * vit_image_size == target_height, \ + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + assert (target_height % vit_image_size == 0 + and target_width % vit_image_size == 0) + + padding_height, padding_width = _get_padding_size(image, target_height, + target_width) + assert padding_width == 0 or padding_height == 0, \ + "padding_width or padding_height must be 0" + + target_feat_width = target_width // vit_patch_size + target_feat_height = target_height // vit_patch_size + if padding_width >= vit_patch_size: + assert padding_height == 0, "padding_height not 0" + non_pad_feat_width = target_feat_width - math.floor( + padding_width / vit_patch_size) + non_pad_feat_height = target_feat_height + elif padding_height >= vit_patch_size: + assert padding_width == 0, "padding_width not 0" + non_pad_feat_height = target_feat_height - math.floor( + padding_height / vit_patch_size) + non_pad_feat_width = target_feat_width + else: + # small padding shorter than a vit patch + non_pad_feat_width = target_feat_width + non_pad_feat_height = target_feat_height + + feat_width = non_pad_feat_width // token_compression_factor + feat_height = non_pad_feat_height // token_compression_factor + # NOTE it's possible that the non-padding feature is not divisible + if non_pad_feat_width % token_compression_factor != 0: + feat_width += 1 + if non_pad_feat_height % token_compression_factor != 0: + feat_height += 1 + num_hd_patch_tokens = feat_width * feat_height + num_hd_newline_tokens = feat_height + vit_feature_size = vit_image_size // vit_patch_size + num_global_image_tokens = (vit_feature_size // token_compression_factor)**2 + num_sep_tokens = 1 + num_global_image_newline_tokens = \ + vit_feature_size // token_compression_factor + + return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens + + num_hd_newline_tokens + num_global_image_newline_tokens) + + +def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: + """ + Compute the output size of the `extract_features` 
method. + + Args: + wav_length (int): Length of the input waveform in samples. + fs (int): Sampling rate of the waveform, either 16000 or 8000. + + Returns: + tuple (int, int): Output size as (T, D), where: + T: Number of time frames. + D: Number of Mel filterbank bins (80). + """ + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav_length //= fs // 16000 + fs = 16000 + elif 8000 <= fs < 16000: + # We'll resample to 16K from 8K + wav_length *= 2 + fs = 16000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + # Spectrogram parameters for 16 kHz + win_length = 400 # Frame length in samples + hop_length = 160 # Frame shift in samples + mel_bins = 80 # Number of mel filterbank bins + + # Calculate number of frames (T) + T = (wav_length - win_length) // hop_length + 1 + if T < 1: + raise ValueError("Waveform too short for given parameters.") + + # Return time frames (T) and mel bins (D) + return T, mel_bins + + +def _get_audio_embed_sizes(audios, ctx: InputContext): + """ + Get the audio embedding sizes for each audio file. + + Args: + audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of + waveform and sample rate. + ctx (InputContext): Input context. + + Returns: + List[int]: List of audio embedding sizes. + """ + audio_embed_sizes = [] + for audio in audios: + audio_data, sf = audio + audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) + audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(), + audio_frames) + audio_embed_sizes.append(audio_embed_size) + return audio_embed_sizes + + +def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): + """ + The following will search for `<|audio_{idx}|>` tokens and + return a mapping of audio placeholder tokens to audio placeholder token ids + based on the size of the audio embeddings. + + Args: + audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of + waveform and sample rate. + ctx (InputContext): Input context. + prompt_str (str): The prompt string. + + Returns: + Dict[str, List[int]]: Mapping of audio placeholder tokens to audio + placeholder token ids. + + """ + if len(audios) == 0: + return {} + + audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) + audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) + audio_ids = [int(audio_id) for audio_id in audio_ids] + assert len(audio_ids) == len( + audio_embed_sizes + ), "Number of audio tokens and audio features do not match" + assert tuple(audio_ids) == tuple(range(1, + len(audio_ids) + + 1)), "Audio ids are not in order!" 
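+    # For example, with two audios the mapping built below is (schematically):
+    #   {"<|audio_1|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_sizes[0],
+    #    "<|audio_2|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_sizes[1]}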
+ audio_id_to_input_ids = { + f"<|audio_{audio_id}|>": + [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) + } + + return audio_id_to_input_ids + + +def _count_image_tokens(images, ctx: InputContext): + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + image_token_counts = [ + _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, + vit_patch_size, token_compression_factor) + for image in images + ] + return image_token_counts + + +def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): + if len(images) == 0: + return {} + + image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) + image_ids = [int(image_id) for image_id in image_ids] + assert len(image_ids) == len( + set(image_ids)), "Duplicate image tokens in prompt" + assert len(images) == len( + image_ids), "Number of images and image tokens in prompt do not match" + + # NOTE the following assertion is not strictly necessary + assert tuple(image_ids) == tuple(range(1, + len(image_ids) + + 1)), "Image ids are not in order" + + image_token_counts = _count_image_tokens(images, ctx) + image_id_to_input_ids = { + f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens + for image_id, num_tokens in zip(image_ids, image_token_counts) + } + return image_id_to_input_ids + + +def input_processor_for_phi4mm(ctx: InputContext, + inputs: DecoderOnlyInputs) -> TokenInputs: + """ + Implements the input processor, which transforms the input prompt ids + to include the audio placeholder token. This will become the `input_ids` + in `forward` for the model. + + Args: + ctx (InputContext): Input context. + inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) + to process. 
+ + Returns: + TokenInputs: Processed inputs + """ + multi_modal_data = inputs.get("multi_modal_data") + if (multi_modal_data is None or + ("audio" not in multi_modal_data and "image" not in multi_modal_data)): + # pure text input, so no need to do pre-processing + return inputs + + prompt_str = inputs.get("prompt") + prompt_token_ids = inputs.get("prompt_token_ids") + # for offline_inference, we will get str input and we parse MM special + # tokens from it + # (ignore prompt_token_ids) + # for OAI server, we will get prompt_token_ids, where MM special tokens + # are already parsed + + if 'audio' in multi_modal_data: + audios = multi_modal_data["audio"] + + if not isinstance(audios, list): + audios = [audios] + if prompt_str is not None: + audio_id_to_input_ids = _get_audio_id_to_input_ids( + audios, ctx, prompt_str=prompt_str) + audio_embed_sizes = [] + elif prompt_token_ids is not None: + audio_id_to_input_ids = {} + audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) + else: + audio_id_to_input_ids = {} + audio_embed_sizes = [] + + if 'image' in multi_modal_data: + # PIL Image or list of PIL Images + images = multi_modal_data["image"] + if not isinstance(images, list): + images = [images] + if prompt_str is not None: + image_id_to_input_ids = _get_image_id_to_input_ids( + images, prompt_str, ctx) + image_token_counts = [] + elif prompt_token_ids is not None: + image_id_to_input_ids = {} + image_token_counts = _count_image_tokens(images, ctx) + else: + image_id_to_input_ids = {} + image_token_counts = [] + + # Handle the case where the prompt is a string and we need to manually + # tokenize it. + # In this case, the `audio_id_to_input_ids` dict will be mapping from + # an audio placeholder + # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the + # given audio length. 
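+    # Illustrative example (prompt format is schematic, not taken from the
+    # source): a prompt such as
+    #   "<|user|><|audio_1|>Transcribe the audio.<|end|>"
+    # is split on the placeholder pattern below; "<|audio_1|>" expands to
+    # [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size, while the surrounding
+    # text chunks are tokenized normally and concatenated in order.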
+ if prompt_str: + pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)" + prompt_chunk_strings = re.split(pattern, prompt_str) + prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] + + # Create the new input_ids with the placeholder image and audio + # tokens inserted + tokenizer = cached_tokenizer_from_config(ctx.model_config) + input_ids = [] + has_imag, has_audio, has_user_text_input = False, False, False + for prompt_chunk_string in prompt_chunk_strings: + if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string): + input_ids.extend(image_id_to_input_ids[prompt_chunk_string]) + has_imag = True + elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string): + input_ids.extend(audio_id_to_input_ids[prompt_chunk_string]) + has_audio = True + else: + curr_token_ids = tokenizer(prompt_chunk_string).input_ids + if not has_user_text_input: + for token_id in curr_token_ids: + if token_id not in NON_USER_INPUT_TOKENS: + has_user_text_input = True + break + input_ids.extend(curr_token_ids) + if has_audio and has_imag and has_user_text_input: + raise ValueError( + "Phi4MMForCausalLM does not support text + audio + image" + + " inputs in the same prompt") + # Handle the case where the prompt is already tokenized + else: + assert prompt_token_ids is not None, \ + "If string prompt isn't provided, prompt_token_ids must be" + + i = 0 + input_ids = prompt_token_ids + # only needed for later assertion + img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 + image_token_count_iter = iter(image_token_counts) + audio_embed_size_iter = iter(audio_embed_sizes) + while i < len(input_ids): + token_id = input_ids[i] + if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID: + token_count = next(audio_embed_size_iter) + audio_cnt += 1 + elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID: + token_count = next(image_token_count_iter) + img_cnt += 1 + else: + user_text_input_cnt += 1 if token_id not in \ + NON_USER_INPUT_TOKENS else 0 + i += 1 + continue + tokens = [token_id] * token_count + input_ids = input_ids[:i] + tokens + input_ids[i + 1:] + i += token_count + + if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: + raise ValueError( + "Phi4MMForCausalLM does not support text + audio + image" + + " inputs in the same prompt") + # If the below assertion fails, it might be that input pure-text + # messages contain image/audio special tokens literally + # (<|endoftext10|>, <|endoftext11|>). + assert (img_cnt == len(image_token_counts)), ( + f"Number of image tokens in prompt_token_ids ({img_cnt}) " + f"does not match number of images ({len(image_token_counts)})") + assert (audio_cnt == len(audio_embed_sizes)), ( + f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " + f"does not match number of audios ({len(audio_embed_sizes)})") + + # NOTE: Create a defensive copy of the original inputs + return token_inputs( + prompt_token_ids=input_ids, + prompt=prompt_str, + multi_modal_data=multi_modal_data, + ) + + +def _compute_audio_embed_size(hf_config, audio_frames): + """ + Compute the audio embedding size based on the audio frames and + compression rate. 
+ """ + compression_rate = hf_config.embd_layer['audio_embd_layer'][ + 'compression_rate'] + # NOTE: this is a hard-coded value but might be configurable in the future + qformer_compression_rate = 1 + integer = audio_frames // compression_rate + remainder = audio_frames % compression_rate + + result = integer if remainder == 0 else integer + 1 + + integer = result // qformer_compression_rate + remainder = result % qformer_compression_rate + result = integer if remainder == 0 else integer + 1 # qformer compression + + return result + + +def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: + return 10000 + + +def dummy_audio_for_phi4mm(audio_count: int) -> dict: + """ + Create dummy audio data for the Phi4MM model, which is used for profiling. + + Args: + audio_count (int): Number of audio samples. + + Returns: + dict: Dummy audio data. + """ + dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0) + return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count + + +def dummy_image_for_phi4mm(width: int, height: int): + image = Image.new('RGB', (width, height), color='black') + return image + + +def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]) -> DummyData: + """ + Create dummy sequence (input_ids) and audio data for the Phi4MM model, + which is used for profiling. + + In this case, the sequence data is a bunch of 0s with a number of audio + tokens that correspond to the audio embed size of the + _AUDIO_MAX_SOUNDFILE_SIZE. + + Args: + ctx (InputContext): Input context. + seq_len (int): Length of the sequence. + mm_counts (Mapping[str, int]): Multi-modal counts. + + Returns: + Tuple: Dummy sequence data and dummy audio data. + """ + audio_count = mm_counts["audio"] + audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE, + DUMMY_SAMPLING_FREQUENCY) + audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(), + audio_frames) + + image_count = mm_counts["image"] + dummy_image = get_max_dummy_image(ctx) + max_image_tokens = get_max_phi4mm_image_tokens(ctx) + total_image_tokens = image_count * max_image_tokens + + if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: + raise RuntimeError( + f"Phi4MM cannot process {audio_count} audios and {image_count}" + f"images in a prompt, please increase max_model_len to be at" + f" larger than " + f"{audio_feature_size * audio_count + total_image_tokens}" + " or reduce audio/image limit by --limit-mm-per-prompt.") + + if audio_feature_size * audio_count > total_image_tokens: + seq_data = SequenceData.from_prompt_token_counts( + (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), + (0, seq_len - audio_feature_size * audio_count), + ) + mm_data = { + "audio": dummy_audio_for_phi4mm(audio_count), + } + else: + seq_data = SequenceData.from_prompt_token_counts( + (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens), + (0, seq_len - total_image_tokens), + ) + mm_data = { + "image": [dummy_image] * image_count, + } + return DummyData(seq_data, mm_data) + + +def input_mapper_for_phi4mm_audio(ctx: InputContext, + data: object) -> MultiModalKwargs: + """ + This function is used to create the MultiModalKwargs for the Phi4MM + (audio) model. + Specifically, for audio, we extract the audio features from the sound + file and create pairs of audio features and audio embed lengths (the + latter of which is used to repeat the audio placeholder token in the + input prompt IDs). 
+ These pairs are used, downstream, in `_audio_features_to_embeddings` + (via `_process_audio_input`). + + Note that the incoming audio data (each entry in `data`) is a tuple of + the audio data and the sampling frequency (e.g. from soundfile.read). + + Args: + ctx (InputContext): Input context. + data (object): Audio data. + + Returns: + MultiModalKwargs: Multi-modal inputs. + """ + if not isinstance(data, list): + data = [data] + + if len(data) == 0: + return MultiModalKwargs() + + audio_features = [] + for audio_input in data: + if not isinstance(audio_input, tuple): + raise NotImplementedError( + f"Unsupported data type: {type(audio_input)}") + + audio, sf = audio_input + feature_extractor = audio_feature_extractor() + single_audio_features = feature_extractor.extract_features(audio, sf) + feat_stride = (1 if not hasattr(feature_extractor, "stride") else + feature_extractor.stride) + audio_frames = len(single_audio_features) * feat_stride + single_audio_embed_size = _compute_audio_embed_size( + ctx.get_hf_config(), audio_frames) + single_audio_feature_audio_len_pair = ( + single_audio_features, + [single_audio_embed_size], + ) + audio_features.append(single_audio_feature_audio_len_pair) + return MultiModalKwargs({"audio_features": audio_features}) + + +def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): + if not isinstance(data, list): + data = [data] + # data: list of PIL images + if len(data) == 0: + return MultiModalKwargs() + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + + image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, + vit_patch_size) + return MultiModalKwargs({ + "pixel_values": + image_input_dict["pixel_values"], + "image_sizes": + image_input_dict["image_sizes"], + "image_attention_mask": + image_input_dict["image_attention_mask"], + "num_img_tokens": + image_input_dict["num_img_tokens"], + }) + + +def cat_with_pad(tensors, dim, padding_value=0): + """ + cat along dim, while pad to max for all other dims + """ + ndim = tensors[0].dim() + assert all( + t.dim() == ndim for t in + tensors[1:]), "All tensors must have the same number of dimensions" + + out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] + out_size[dim] = sum(t.shape[dim] for t in tensors) + output = tensors[0].new_full(out_size, padding_value) + + index = 0 + for t in tensors: + # Create a slice list where every dimension except dim is full slice + slices = [slice(0, t.shape[d]) for d in range(ndim)] + # Update only the concat dimension slice + slices[dim] = slice(index, index + t.shape[dim]) + + output[slices] = t + index += t.shape[dim] + + return output + + +@MULTIMODAL_REGISTRY.register_input_mapper("audio", + input_mapper_for_phi4mm_audio) +@MULTIMODAL_REGISTRY.register_input_mapper("image", + input_mapper_for_phi4mm_image) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_max_phi4mm_audio_tokens) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "image", get_max_phi4mm_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) +@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) +class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, + SupportsV0Only): + """ + 
Implements the Phi-4-multimodal-instruct model in vLLM. + """ + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "base_layer.": "", + }, + orig_to_new_prefix={ + "model.embed_tokens_extend.audio_embed.audio_projection.vision.": + "embed_tokens_extend.audio_projection_for_vision.", + "model.embed_tokens_extend.audio_embed.audio_projection.speech.": + "embed_tokens_extend.audio_projection.", + "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.", + "model.embed_tokens_extend.image_embed.": "vision_encoder.", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + assert multimodal_config, "multimodal_config is required" + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.multimodal_config = multimodal_config + self.quant_config = quant_config + self.lora_config = lora_config + + # Tensor/Pipeline parallel not supported for now. + assert get_pp_group( + ).world_size == 1, "pipeline parallel is not supported" + + self.vision_encoder = Phi4MMImageEncoder( + config, + quant_config, + prefix="model.vision_embed_tokens", + model_dir=config._name_or_path) + + if isinstance(config.embd_layer["audio_embd_layer"], dict): + embedding_config = { + "embedding_cls": + config.embd_layer["audio_embd_layer"]["embedding_cls"], + **config.embd_layer["audio_embd_layer"], + } + else: + embedding_config = { + "embedding_cls": self.config.embd_layer["embedding_cls"] + } + + self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) + self.model = LlamaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + def _audio_features_to_embeddings( + self, + input_ids: torch.Tensor, + input_features: List[torch.Tensor], + audio_input_sizes: torch.Tensor, + audio_projection_mode: str, + ) -> torch.Tensor: + """ + Convert audio features to embeddings, which are used as input to the + model (via `inputs_embeds`). + + Args: + input_ids (torch.Tensor): Input IDs (the prompt in this case). + input_features (list[torch.Tensor]): Input features (the audio + embeddings). + audio_input_sizes (list[torch.Tensor]): Audio input sizes (the + audio embed lengths to use for padding the audio placeholder token + in the input prompt IDs). 
+ """ + # The audio projection can either be a single linear or Sequential, + # so handle both cases + if isinstance(self.embed_tokens_extend.audio_projection, + nn.Sequential): + target_dtype = self.embed_tokens_extend.audio_projection[ + 0].bias.dtype + else: + target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype + + audio_input = [ + input.unsqueeze(0).to(target_dtype) for input in input_features + ] + kwargs = { + "wte": self.model.embed_tokens, + 'audio_projection_mode': audio_projection_mode + } + audio_embeddings = self.embed_tokens_extend(input_ids, audio_input, + audio_input_sizes, + **kwargs) + audio_embeddings = audio_embeddings.to(target_dtype) + return audio_embeddings + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: + """ + Parse and validate the audio input to the model. This handles both + audio features and audio embeddings, but only the former is used for + now. + + Args: + kwargs (object): Keyword arguments. + + Returns: + Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. + """ + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(audio_features)}") + + return Phi4MMAudioFeatureInputs(type="audio_features", + data=audio_features) + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input(self, input_ids: torch.Tensor, + audio_input: Phi4MMAudioInputs, + audio_projection_mode: str) -> NestedTensors: + """ + Create the audio embeddings from the audio input, where the audio input + is pairs of audio features and audio embed lengths. The audio input is + created by `input_mapper_for_phi4mm_audio`. + + Args: + input_ids (torch.Tensor): Input IDs (the prompt in this case, + before the audio token replication). + audio_input (Phi4MMAudioInputs): Audio input. + + Returns: + NestedTensors: Audio embeddings + """ + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + audio_features = audio_input["data"] + # (e.g. multiple examples) and the second dim is the multi-audio dim + # (e.g. 
multiple audios in the same example) + audio_feature = [i[0] for j in audio_features for i in j] + audio_feature_len = [i[1].item() for j in audio_features for i in j] + # Add the batch dim via `squeeze` + + return self._audio_features_to_embeddings( + input_ids.unsqueeze(0), + audio_feature, + audio_feature_len, + audio_projection_mode, + ).squeeze(0) + + def _parse_and_validate_image_input(self, + **kwargs: object) -> Optional[Dict]: + pixel_values: Optional[Dict] = kwargs.get("pixel_values") + if pixel_values is None: + return None + + image_sizes = kwargs.get("image_sizes") + image_attention_mask = kwargs.get("image_attention_mask") + num_img_tokens = kwargs.get("num_img_tokens") + assert image_sizes is not None and image_attention_mask is not None\ + and num_img_tokens is not None, "Missing image inputs" + + if isinstance(pixel_values, list): + assert pixel_values[0].dim() == 5, "Incorrect image inputs" + # list len is batch_size. + # each tensor has dimension: num_img_per_example, num_hd_patches, + # channels, height, width. + # need to pad along num_hd_patches. + # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. + pixel_values = cat_with_pad(pixel_values, dim=0) + elif isinstance(pixel_values, torch.Tensor): + # dimension: batch_size, num_img_per_example, num_hd_patches, + # channels, height, width. + # we flatten first 2 dims to make it a single large batch for + # SigLIP Encoder. + assert pixel_values.dim() == 6, "Incorrect image inputs" + pixel_values = pixel_values.flatten(0, 1) + else: + raise ValueError("Incorrect pixel_values inputs") + + if isinstance(image_attention_mask, list): + image_attention_mask = cat_with_pad(image_attention_mask, dim=0) + elif isinstance(image_attention_mask, torch.Tensor): + image_attention_mask = image_attention_mask.flatten(0, 1) + else: + raise ValueError("Incorrect image_attention_mask inputs") + + if isinstance(image_sizes, list): + image_sizes = torch.cat(image_sizes, dim=0) + elif isinstance(image_sizes, torch.Tensor): + image_sizes = image_sizes.flatten(0, 1) + else: + raise ValueError("Incorrect image_attention_mask inputs") + + if isinstance(num_img_tokens, list): + num_img_tokens = [ + n for num_tensor in num_img_tokens + for n in num_tensor.tolist() + ] + elif isinstance(num_img_tokens, torch.Tensor): + num_img_tokens = num_img_tokens.flatten(0, 1).tolist() + else: + raise ValueError("Incorrect image_attention_mask inputs") + + return { + 'pixel_values': pixel_values, + 'image_sizes': image_sizes, + 'image_attention_mask': image_attention_mask, + 'num_img_tokens': num_img_tokens, + } + + def merge_image_features_to_inputs_embeds( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_set_tensors: List[torch.Tensor], + ): + position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( + as_tuple=True) + + assert all([t.shape[0] == 1 for t in image_set_tensors + ]), 'img_set_tensor should have shape (1, N_tokens, C)' + # Shape: (merged_N_tokens, C) + image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) + image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( + inputs_embeds.device) + merged_embeds = inputs_embeds.index_put( + indices=position_tuple, + values=image_set_tensor, + accumulate=False, + ) + return merged_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> torch.Tensor: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = 
None + else: + # Each entry in this is a pair of audio_features and audio_embed + # lengths + audio_input = self._parse_and_validate_audio_input(**kwargs) + image_inputs = self._parse_and_validate_image_input(**kwargs) + + has_audio = audio_input is not None + has_image = image_inputs is not None + + if has_audio: + audio_projection_mode = 'vision' if has_image else 'speech' + inputs_embeds = self._process_audio_input( + input_ids, audio_input, audio_projection_mode) + + if has_image: + dtype = self.vision_encoder.img_processor.embeddings.\ + patch_embedding.weight.dtype + pixel_values = image_inputs['pixel_values'].to(dtype) + image_sizes = image_inputs['image_sizes'] + image_attention_mask = image_inputs['image_attention_mask'] + image_set_tensors = self.vision_encoder( + pixel_values, image_sizes, image_attention_mask) + if not has_audio: + inputs_embeds = self.model.embed_tokens(input_ids) + + inputs_embeds = self.merge_image_features_to_inputs_embeds( + input_ids, inputs_embeds, image_set_tensors) + + if has_image or has_audio: + # multi-modal input, we have set inputs_embeds properly in + # previous steps + input_ids = None + else: + # text-only, we keep using original input_ids + inputs_embeds = None + + hidden_states = self.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> None: + weights = ((name, data) for name, data in weights + if "lora" not in name) + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model.", + connector=["audio_projection_for_vision", "audio_projection"], + tower_model=["vision_encoder", "embed_tokens_extend"], + ) diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py new file mode 100644 index 0000000..db90848 --- /dev/null +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -0,0 +1,1271 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
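Editor's note: the `forward` above either passes `input_ids` (text-only) or `inputs_embeds` (any multimodal input) to the language model, and picks the audio projection head based on whether an image is also present. A condensed sketch of that dispatch is shown below; the helper callables are stand-ins for the real embedding and merge modules, not vLLM APIs.

```python
# Sketch of the modality dispatch: exactly one of (input_ids, inputs_embeds)
# is forwarded to the language model. Names here are illustrative.
from typing import Callable, Optional, Tuple
import torch

def dispatch_inputs(
    input_ids: torch.Tensor,
    audio_embeds: Optional[torch.Tensor],     # prompt embeddings with audio merged in
    image_embeds: Optional[torch.Tensor],     # per-image feature tensors
    embed_text: Callable[[torch.Tensor], torch.Tensor],
    merge_image: Callable[..., torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], str]:
    has_audio = audio_embeds is not None
    has_image = image_embeds is not None

    # Audio goes through the vision-speech projection head when an image is
    # present in the same prompt, otherwise through the speech-only head.
    projection_mode = "vision" if (has_audio and has_image) else "speech"

    if has_audio:
        inputs_embeds = audio_embeds
    elif has_image:
        inputs_embeds = embed_text(input_ids)          # start from plain text embeddings
    else:
        return input_ids, None, projection_mode        # text-only path keeps token ids

    if has_image:
        inputs_embeds = merge_image(input_ids, inputs_embeds, image_embeds)
    return None, inputs_embeds, projection_mode        # multimodal path passes embeddings
```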
+# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) +# but implemented by the Phi-Speech team +#!/usr/bin/env python3 +import abc +import math +from typing import List, Literal, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel) +from transformers import PretrainedConfig + +from vllm.model_executor.models.phi4mm_utils import ( + AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, + MultiHeadedAttention, MultiSequential, NemoConvSubsampling, + T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor) + +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a + channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_innner_dim: int, optional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. otherwise attention_innner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". 
+ activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention + implementation in training. + attn_group_sizes: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_innner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_innner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention= + use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + ) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. + + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. 
relative positions + (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v, mask + + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. 
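Editor's note: the residual ordering in `ConformerEncoderLayer.forward` is the macaron pattern from the Conformer paper. The sketch below reproduces only that ordering with generic stand-in sub-modules; the real layer additionally threads positional keys/values, masks, and an optional relative-attention bias through the attention call, all omitted here.

```python
# Macaron-style conformer block: half-weighted FFN -> self-attention on the
# layer-normed input -> convolution -> half-weighted FFN -> final LayerNorm.
import torch
from torch import nn

class TinyConformerBlock(nn.Module):
    def __init__(self, d_model: int = 64, d_ffn: int = 128, n_head: int = 4):
        super().__init__()
        self.ffn_in = nn.Sequential(nn.Linear(d_model, d_ffn), nn.SiLU(),
                                    nn.Linear(d_ffn, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.ffn_out = nn.Sequential(nn.Linear(d_model, d_ffn), nn.SiLU(),
                                     nn.Linear(d_ffn, d_model))
        self.norm_att = nn.LayerNorm(d_model)
        self.norm_final = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 0.5 * self.ffn_in(x)                      # first macaron half-FFN
        n = self.norm_att(x)
        x = x + self.attn(n, n, n, need_weights=False)[0]
        x = x + self.conv(x.transpose(1, 2)).transpose(1, 2)
        x = x + 0.5 * self.ffn_out(x)                     # second macaron half-FFN
        return self.norm_final(x)

out = TinyConformerBlock()(torch.randn(2, 10, 64))        # (batch, time, d_model)
print(out.shape)                                          # torch.Size([2, 10, 64])
```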
+ Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", + True] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.embed = NemoConvSubsampling(**default_nemo_conv_settings, ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.pos_emb = AbsolutePositionalEncoding(attention_dim, + positional_dropout_rate) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") + if relative_attention_bias_args else None) + if self.relative_attention_bias_type == "t5": + assert (self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get( + "t5_bias_max_distance", 1000), + symmetric=relative_attention_bias_args.get( + "t5_bias_symmetric", False), + ) + else: + raise NotImplementedError + + self.encoder_embedding = MeanVarianceNormLayer( + self.encoder_embedding_config["input_size"]) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that + computed the right thing. That does not work within Torchscript. + If you really need this to be faster, create nn.Module()-s for all + the cases and return one of them. Torchscript does support that. 
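Editor's note: `compute_lens_change` (defined just below) tracks how many frames survive the front-end subsampling. The sketch here covers only the common non-causal case, where the output length is simply the ceiling of the input length divided by the time-reduction factor; the causal branch additionally adds one frame for certain remainders, which is omitted.

```python
# Length bookkeeping after subsampling by `time_reduction`. Numbers are
# illustrative only.
import math

def subsampled_len(feature_len: int, time_reduction: int = 4) -> int:
    return math.ceil(feature_len / time_reduction)

for n in (15, 16, 17, 1000):
    print(n, "->", subsampled_len(n))
# 15 -> 4, 16 -> 4, 17 -> 5, 1000 -> 250
```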
+ """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get( + "subsampling", "dw_striding") in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = (torch.ceil(feature_lens / + self.time_reduction).long() + if isinstance(feature_lens, Tensor) else + math.ceil(feature_lens / self.time_reduction)) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + ceil_func = (math.ceil + if isinstance(feature_lens, int) else torch.ceil) + return ceil_func(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int( + torch.randint(low=0, high=len(chunk_size), size=(1, ))) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError( + "Since chunk_size is a list, left_chunk must be a list") + if len(left_chunk) != len(chunk_size): + raise ValueError( + "The length of left_chunk must be the same as length of "\ + "chunk_size." + ) + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is None: + input_tensor = self.pos_emb( + input_tensor) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = \ + self._chunk_size_selection(chunk_size, left_chunk) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] 
+ chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + + enc_streaming_mask = (adaptive_enc_mask( + seq_len, chunk_start_idx, + left_window=left_chunk_train_eff).unsqueeze(0).expand( + [batch_size, -1, -1])) + return enc_streaming_mask + + def forward_embeddings(self, + xs_pad, + masks, + chunk_size_nc=None, + left_chunk_nc=None): + """Forwarding the inputs through the top embedding layers + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for + non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for + non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = int(self.compute_lens_change(xs_pad.shape[1])) + if seq_len <= 0: + raise ValueError( + f"""The sequence length after time reduction is invalid: + {seq_len}. Your input feature is too short. Consider + filtering out the very short sentence from data + loader""", ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, + self.chunk_size, + self.left_chunk) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core( + input_tensor, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + left_chunk: int + number of chunks used for masking in streaming mode. 
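Editor's note: `_streaming_mask` builds a chunk-wise attention mask via `adaptive_enc_mask`, where each frame may attend within its own chunk and up to `left_chunk` previous chunks. The sketch below is a simplified stand-in assuming a fixed chunk size and no right context; it is not the library function itself.

```python
# Simplified chunk-wise streaming mask: frame i can see frame j iff j's chunk
# is not in the future and at most `left_chunks` chunks in the past.
import torch

def chunk_streaming_mask(seq_len: int, chunk_size: int, left_chunks: int) -> torch.Tensor:
    chunk_idx = torch.arange(seq_len) // chunk_size          # chunk id per frame
    diff = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)   # (seq, seq) chunk offsets
    return (diff >= 0) & (diff <= left_chunks)

print(chunk_streaming_mask(seq_len=6, chunk_size=2, left_chunks=1).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]])
```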
+ num_lang: int + This parameter is used to store the number of languages in the + lang_dict, only used for multiseed/multilingual models. + default None. + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + linear_units: + the number of units of position-wise feed forward. + default 2048 + num_block: + number of Transformer layer. default 6 + dropout_rate: float, optional + dropout rate. default 0.1 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + ext_pw_out_channel: int, optional + the number of channel for CNN + before depthwise_seperable_CNN. + If 0 then use linear. default 0. + ext_pw_kernel_size: int, optional + kernel size of N before depthwise_seperable_CNN. + only work for ext_pw_out_channel > 0. + default 1 + depthwise_seperable_out_channel: int, optional + the number of channel for + depthwise_seperable_CNN. + default 256. + depthwise_multiplier: int, optional + the number of multiplier for + depthwise_seperable_CNN. + default 1. + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + kernel_size: int, optional + the number of kernels for depthwise_seperable_CNN. + default 3. + activation: str, optional + FeedForward block activation. + one of ["relu", "swish", "sigmoid"] + default "relu". + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation used use glu in depthwise_seperable_CNN, + default "sigmoid" + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. default True + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_glu_type: str + only work for glu_in_attention !=0 + default "swish". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + extra_layer_output_idx: int + the layer index to be exposed. + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. 
+ [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + time_reduction: int optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled + dot product attention in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming + decoding, use "replication" padding for the cache at start of + utterance. + Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: List[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], # noqa + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", + True] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.replication_pad_for_subsample_embedding: bool = ( + replication_pad_for_subsample_embedding) + assert (self.num_heads % attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = MultiSequential(*[ + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + 
dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=activation_checkpointing, + export=export, + use_pt_scaled_dot_product_attention= + use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) for _ in range(num_blocks) + ]) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + self.register_buffer("dev_type", torch.zeros(()), persistent=False) + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, + self.chunk_size, + self.left_chunk) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = (torch.arange(0, max_audio_length, + device=device).expand(padding_length.size(0), + -1) + < padding_length.unsqueeze(1)) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( + xs_pad, masks) + + unfolded = False + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 #maximum position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks + # of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple + # of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + input_tensor_pad = F.pad(input_tensor, + (0, 0, 0, chunk_pad_size), "constant", + 0) + input_tensor = input_tensor_pad.to(input_tensor.device) + input_tensor = unfold_tensor(input_tensor, max_seq_len) + if masks is not None: + # revise hs_mask here because the previous calculated hs_mask + # did not consider extra pad + subsampled_pad_mask = masks.squeeze( + 1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", + False) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = \ + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze( + -1).bool() # unfold op does not support bool tensor + else: + masks_unfold = None + hs_mask = self.calculate_hs_mask( + input_tensor, input_tensor.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + # layer_emb = None + + relative_attention_bias = 
self.init_relative_attention_bias( + input_tensor) + + _simplified_path = (self.extra_layer_output_idx == -1 + and relative_attention_bias is None) + + if _simplified_path: + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, + hs_mask) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + # if i == self.extra_layer_output_idx: + # layer_emb = input_tensor + + if unfolded: + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + input_tensor = input_tensor[:, :-chunk_pad_size, :] + + return input_tensor, masks # , layer_emb + + +class WindowQformer(nn.Module): + """Window-level Qformer""" + + def __init__( + self, + window_size: int = 8, + num_queries: int = 1, + num_blocks: int = 2, + attention_dim: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + dropout_rate: float = 0.0, + normalize_before: bool = True, + ): + super().__init__() + + self.decoders = nn.ModuleList([ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) for _ in range(num_blocks) + ]) + + self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) + self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12) + if normalize_before else None) + self.window_size = window_size + + def forward(self, audio_embed, mask, embed_len=None): + """forward decoder""" + # audio_embed: N x T x D => N x D x T + + audio_embed = audio_embed.transpose(1, 2) + # audio_embed: N x D x 1 x T => N x DK x T' + padding = audio_embed.shape[-1] % self.window_size + if padding > 0: + audio_embed = F.pad(audio_embed, (0, self.window_size - padding), + "constant", 0) + + embed_chunk = F.unfold( + audio_embed[..., None, :], + kernel_size=(1, self.window_size), + stride=(1, self.window_size), + ) + bsz, _, slen = embed_chunk.shape + # N x D x K x T' + embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) + # N x T' x K x D + embed_chunk = embed_chunk.transpose(1, 3).contiguous() + # NT' x K x D + embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) + # NT' x 1 x D + q = self.queries.expand(bsz * slen, -1, -1) + for layer in self.decoders: + q = layer(tgt=q, + memory=embed_chunk, + tgt_mask=None, + memory_mask=mask) + + if self.after_norm is not None: + q = self.after_norm(q) + + if embed_len is not None: + embed_len = embed_len // self.window_size + # N x T' x D + out = q.view(bsz, slen, -1) + + return out, embed_len + + +class AudioEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = (config.n_embd + if hasattr(config, "n_embd") else config.hidden_size) + + # self.wte = nn.Embedding(config.vocab_size, hidden_size) + + audio_dim_out = ( + None # Set this variable according to the actual audio processor + ) + self.layer_idx = -2 + + if (isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades"): + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = ConformerEncoder(**encoder_config) + + 
audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError("") + + assert (audio_dim_out + is not None), "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", + False) + + self.downsample_rate = kwargs.get("downsample_rate", 1) + + if kwargs.get("use_qformer", False): + qformer_config = kwargs.get("qformer_config", {}) + qformer_config["attention_dim"] = audio_dim_out + self.qformer = WindowQformer(**qformer_config) + else: + self.qformer = None + + if kwargs.get("use_conv_downsample", False): + assert (self.qformer is None + ), "don't support use qformer and conv downsample together" + nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.downsample_rate, + "feat_in": audio_dim_out, + "feat_out": audio_dim_out, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, ) + else: + self.conv_ds = None + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds) + else self.downsample_rate) + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, + dim_projection) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.audio_projection = nn.Sequential(*layers) + # NOTE vision-speech tasks use a separate projection layer + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, + dim_projection) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.audio_projection_for_vision = nn.Sequential(*layers) + else: + raise NotImplementedError( + f"projection_cls = {projection_cls}, not implemented") + + # TODO: audio sequence compression - Qformer + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes(self, + audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features( + self, + input_embeds: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", + ): + + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder(input_embeds, + audio_attention_mask) + else: + audio_features, masks = self.encoder(input_embeds, + audio_attention_mask) + + if self.qformer is not None: + audio_features, _ = self.qformer(audio_features, mask=None) + + if self.conv_ds is not None: + if masks is not None: + masks = 
masks.squeeze(1) + + audio_features, masks = self.conv_ds(audio_features, mask=masks) + + if self.linear_downsample_rate != 1: + bs, seq_len, feat_dim = audio_features.size() + padding = seq_len % self.linear_downsample_rate + if padding > 0: + audio_features = F.pad( + audio_features, + (0, 0, 0, self.linear_downsample_rate - padding), + "constant", + 0, + ) + + seq_len = audio_features.size(1) + audio_features = audio_features.view( + bs, + seq_len // self.linear_downsample_rate, + feat_dim * self.linear_downsample_rate, + ) + + if audio_projection_mode == 'speech': + audio_set_tensor = self.audio_projection(audio_features) + elif audio_projection_mode == 'vision': + audio_set_tensor = self.audio_projection_for_vision(audio_features) + else: + raise ValueError( + f"audio_projection_mode = {audio_projection_mode} not "\ + "implemented" + ) + + return audio_set_tensor + + def forward( + self, + input_ids: torch.LongTensor, + input_embeds: torch.FloatTensor, + audio_embed_sizes, + **kwargs, + ) -> torch.FloatTensor: + """ + arguments: + input_ids: input text ids (B, U) + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ + assert input_embeds is not None and len(input_embeds) == len( + audio_embed_sizes) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + with torch.no_grad(): + positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero( + as_tuple=False) + + if not isinstance(input_embeds, list): + input_embeds = [input_embeds] + + audio_projection_mode = kwargs.get("audio_projection_mode", "speech") + audio_set_tensor = [ + self.get_audio_features( + input_embed, audio_projection_mode=audio_projection_mode) + for input_embed in input_embeds + ] + + with torch.no_grad(): + input_ids.clamp_min_(0).clamp_max_(self.vocab_size) + + if "wte" in kwargs: + # we use the token embedding layer from the huggingface model, this + # is REQUIRED to make sure we are using the loaded weights. + hidden_states = kwargs["wte"](input_ids) + else: + # otherwise, we use token embedding in pretrained mixformer from + # phi team + hidden_states = self.wte(input_ids) + + if len(positions.tolist()) > 0: + assert sum(audio_embed_sizes) == len( + positions + ), "please ensure the encoder outputs have the same length as"\ + " defined in input_ids!" + idx = 0 + for i in range(len(audio_embed_sizes)): + cnt = audio_embed_sizes[i] + assert audio_set_tensor[i].shape[0] == 1 + hidden_states[ + positions[idx, 0], + positions[idx, 1]:positions[idx, 1] + cnt, + ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to( + hidden_states.dtype).to(hidden_states.device)) + idx += cnt + + return hidden_states diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py new file mode 100644 index 0000000..9f08a1c --- /dev/null +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -0,0 +1,1883 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
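Editor's note: when neither the Q-Former nor the convolutional downsampler is used, `get_audio_features` reduces the frame rate by stacking consecutive frames into the feature dimension before the projection. A minimal sketch of that reshape is shown below; the function name and shapes are illustrative assumptions.

```python
# Frame stacking before the audio projection: pad the time axis to a multiple
# of the rate, then fold every `rate` consecutive frames into the feature dim.
import torch
import torch.nn.functional as F

def stack_frames(feats: torch.Tensor, rate: int) -> torch.Tensor:
    bs, seq_len, dim = feats.shape
    pad = (-seq_len) % rate                      # frames needed to reach a multiple
    if pad > 0:
        feats = F.pad(feats, (0, 0, 0, pad))     # zero-pad along the time axis
    return feats.view(bs, (seq_len + pad) // rate, dim * rate)

x = torch.randn(2, 10, 80)                       # (batch, frames, mel bins)
print(stack_frames(x, rate=4).shape)             # torch.Size([2, 3, 320])
```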
+# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) +# but implemented by the Phi-Speech team +#!/usr/bin/env python3 +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Block(nn.Module): + """Block abstract module""" + + def __init__(self, input_size, output_size): + super().__init__() + self.input_size = input_size + self.output_size = output_size + + +def get_activation(name="relu"): + """Select an activation function by name + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". + """ + name = name.lower() + if name == "relu": + return nn.ReLU(inplace=True) + if name == "gelu": + return nn.GELU() + if name == "swish": + return Swish() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for + chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + Torch 1.0.1 + tensor([[1., 1., 0., 0.], + [0., 1., 1., 0.], + [0., 0., 1., 1.]]) + Torch 1.4.1 + tensor([[True., True., False., False.], + [False., True., True., False.], + [False., False., True., True.]]) + """ + chunk_start_idx = torch.Tensor(chunk_start_idx).long( + ) # first idx of each chunk, such as [0,18,36,48]. + start_pad = torch.nn.functional.pad( + chunk_start_idx, + (1, 0)) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, + x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & + (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] + # boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = (torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Swish(nn.Module): + """Implement Swish activation module. + From https://arxiv.org/pdf/2005.03191.pdf + + """ + + def __init__(self) -> None: + super().__init__() + self.act_fn = nn.Sigmoid() + + def forward(self, x: Tensor) -> Tensor: + """Apply Swish function + + Args: + x: torch.Tensor + Input. 
+ """ + return x * self.act_fn(x) + + +class GLU(nn.Module): + """Implement Gated Linear Unit (GLU) module""" + + def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: + super().__init__() + self.dim = dim + self.act_name = act_name.lower() + + if self.act_name == "relu": + self.act_fn = nn.ReLU(inplace=True) + elif self.act_name == "gelu": + self.act_fn = nn.GELU() + elif self.act_name == "swish": + self.act_fn = Swish() + elif self.act_name == "sigmoid": + self.act_fn = nn.Sigmoid() + else: + self.act_fn = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """GLU forward + Apply Swish function on the first half of input matrices + with sigmoid of the second half. + + Args: + x: torch.Tensor + Input. + + """ + half_x, gate = x.chunk(2, dim=self.dim) + return half_x * self.act_fn(gate) + + +# TODO: Abdel, this can be improved using GLU module +class GLUPointWiseConv(nn.Module): + """GLUPointWiseConv module + used for conformer architecture, + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + output_dim: int + output channel size. + kernel_size: int + kernel size + glu_type: str, optional + activation function one of + ["sigmoid", "relu", "gelu"] + default "sigmoid". + bias_in_glu: bool, optional + use addtive bias in glu + causal: bool, optional + if set to True, padding is set to the half of + kernel size, ie, convolution can't see future frames. + default False. + + """ + + def __init__( + self, + input_dim, + output_dim, + kernel_size, + glu_type="sigmoid", + bias_in_glu=True, + causal=False, + ): + super().__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + self.bias_in_glu = bias_in_glu + if causal: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1), + ) + else: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + + if glu_type == "sigmoid": + self.glu_act = nn.Sigmoid() + elif glu_type == "relu": + self.glu_act = nn.ReLU() + elif glu_type == "gelu": + self.glu_act = nn.GELU() + elif glu_type == "swish": + self.glu_act = Swish() + else: + raise ValueError(f"Unsupported activation type {self.glu_act}") + + if bias_in_glu: + self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) + self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) + + def forward(self, x): + """ + Args: + x: torch.Tensor + input tensor + """ + # to be consistent with GLULinear, we assume the input always has the + # #channel (#dim) in the last dimension of the tensor, so need to + # switch the dimension first for 1D-Conv case + x = x.permute([0, 2, 1]) + x = self.ext_pw_conv_1d(x) + if self.glu_type == "bilinear": + if self.bias_in_glu: + x = (x[:, 0:self.output_dim, :] + self.b1) * ( + x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + else: + x = (x[:, 0:self.output_dim, :]) * ( + x[:, self.output_dim:self.output_dim * 2, :]) + else: + if self.bias_in_glu: + x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + else: + x = (x[:, 0:self.output_dim, :]) * self.glu_act( + x[:, self.output_dim:self.output_dim * 2, :]) + + x = x.permute([0, 2, 1]) + return x + + +class DepthWiseSeperableConv1d(nn.Module): + """DepthWiseSeperableConv1d module used in Convnet module + for the conformer, for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. 
+ depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a channel_out + of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + kernel_size: int + kernel_size + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + padding: int, optional + padding for the conv1d, + default: 0. + + """ + + def __init__( + self, + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=0, + ): + super().__init__() + + self.dw_conv = nn.Conv1d( + input_dim, + input_dim * depthwise_multiplier, + kernel_size, + 1, + padding=padding, + groups=input_dim, + ) + + if depthwise_seperable_out_channel != 0: + self.pw_conv = nn.Conv1d( + input_dim * depthwise_multiplier, + depthwise_seperable_out_channel, + 1, + 1, + 0, + ) + else: + self.pw_conv = nn.Identity() + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + + def forward(self, x): + """ + + Args: + x: torch.Tensor + input tensor + """ + x = self.dw_conv(x) + if self.depthwise_seperable_out_channel != 0: + x = self.pw_conv(x) + return x + + +class ConvModule(nn.Module): + """ConvModule Module for the conformer block. + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation. + default False + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + chunk_size: int, optional + chunk size for cnn. default 18 + activation: str, optional + activation function used in ConvModule, + default: "relu". + glu_type: str, optional + activation function used for the glu, + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + export: bool, optional, + if set to True, padding is equal to 0. This is for inference, + or onnx export. Typically this is set by the export program or + the decoder program, and it isn't present in your config file. 
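Editor's note: the depthwise-separable convolution used by `DepthWiseSeperableConv1d` factors an ordinary convolution into a per-channel (grouped) convolution followed by a 1x1 pointwise convolution. The sketch below shows the factorization with assumed sizes; it does not reproduce the class itself.

```python
# Depthwise-separable 1-D convolution: grouped conv per channel, then 1x1 mix.
import torch
from torch import nn

d_model, multiplier, kernel = 64, 1, 3
depthwise = nn.Conv1d(d_model, d_model * multiplier, kernel,
                      padding=(kernel - 1) // 2, groups=d_model)
pointwise = nn.Conv1d(d_model * multiplier, d_model, kernel_size=1)

x = torch.randn(2, d_model, 50)            # (batch, channels, time)
print(pointwise(depthwise(x)).shape)       # torch.Size([2, 64, 50])
```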
+ default False + """ + + def __init__( + self, + input_dim, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal=False, + batch_norm=False, + chunk_se=0, + chunk_size=18, + activation="relu", + glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + export=False, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.input_dim = input_dim + self.ext_pw_out_channel = ext_pw_out_channel + self.ext_pw_kernel_size = ext_pw_kernel_size + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.glu_type = glu_type + self.bias_in_glu = bias_in_glu + self.linear_glu_in_convm = linear_glu_in_convm + self.causal = causal + + self._add_ext_pw_layer() + + self.batch_norm = batch_norm + self.kernel_size = kernel_size + + if batch_norm: + self.bn_layer = nn.BatchNorm1d(input_dim) + + self.act = get_activation(activation) + self.dropout = nn.Dropout(dropout_rate) + self.export = export + + if causal: + padding = 0 if export else kernel_size - 1 + else: + padding = (kernel_size - 1) // 2 + + self.dw_sep_conv_1d = DepthWiseSeperableConv1d( + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=padding, + ) + + if depthwise_seperable_out_channel != 0: + if input_dim != depthwise_seperable_out_channel: + self.ln2 = nn.Linear(depthwise_seperable_out_channel, + input_dim) + else: + if depthwise_multiplier != 1: + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, + input_dim) + + def _add_ext_pw_layer(self): + """ + This function is an extension of __init__ function + and dedicated to the convolution module creation + of the conformer. + """ + self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( + nn.Identity()) # jit hacks. + self.squeeze_excitation = nn.Identity() # jit. + self.apply_ln1 = self.fix_len1 = False # jit. + + if self.ext_pw_out_channel != 0: + if self.causal: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1), + ) + if self.ext_pw_kernel_size > 1: + self.fix_len1 = True + else: + self.fix_len1 = False + else: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1) // 2, + ) + self.fix_len1 = False + + if self.linear_glu_in_convm: + self.glu = GLULinear( + self.input_dim, + self.ext_pw_out_channel, + self.glu_type, + self.bias_in_glu, + ) + else: + self.glu = GLUPointWiseConv( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + self.glu_type, + self.bias_in_glu, + self.causal, + ) + + if self.input_dim != self.ext_pw_out_channel: + self.apply_ln1 = True + self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) + else: + self.apply_ln1 = False + else: + self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) + self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) + + def forward(self, x): + """ConvModule Forward. + + Args: + x: torch.Tensor + input tensor. 
+ """ + x = self.layer_norm(x) + + if self.ext_pw_out_channel != 0: + x = self.glu(x) + if self.causal and self.ext_pw_kernel_size > 1: + x = x[:, :-(self.ext_pw_kernel_size - 1), :] + if self.apply_ln1: + x = self.ln1(x) + else: + x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] + x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] + x = x_0 + x_1 + + x = x.permute([0, 2, 1]) + + x = self.dw_sep_conv_1d(x) + if self.causal and self.kernel_size > 1: + x = x[:, :, :-(self.kernel_size - 1)] + if hasattr(self, "ln2"): + x = x.permute([0, 2, 1]) + x = self.ln2(x) + x = x.permute([0, 2, 1]) + if self.batch_norm: + x = self.bn_layer(x) + x = self.act(x) + + if self.ext_pw_out_channel != 0: + x = self.ext_pw_conv_1d(x) + if self.fix_len1: + x = x[:, :, :-(self.ext_pw_kernel_size - 1)] + + if self.apply_ln1: + x = x.permute([0, 2, 1]) + x = self.ln1(x) + x = x.permute([0, 2, 1]) + + x = x.permute([0, 2, 1]) + else: + x = x.unsqueeze(1).permute([0, 1, 3, 2]) + x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] + x = x.squeeze(1) + + x = self.dropout(x) + return x + + +class GLULinear(nn.Module): + """Linear + GLU module + + Args: + input_dim: int + input size + output_dim: int + output size. + glu_type: + activation function name used in glu module. + default "sigmoid" (swish function). + bias_in_glu: bool, optional + If True, the addtive bias is added. Default False. + """ + + def __init__( + self, + input_dim, + output_dim, + glu_type="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) + self.glu_act = GLU(-1, glu_type) + + def forward(self, x): + """GLULinear forward + + Args: + x: torch.Tensor + inpute tensor. + """ + x = self.linear(x) + return self.glu_act(x) + + +class FeedForward(nn.Module): + """FeedForward Module. + For more details see Conformer paper: + https://arxiv.org/pdf/2005.08100.pdf + + Args: + d_model: int + input size. + d_inner: int + output size. + dropout_rate: float, + dropout rate. + activation: str, + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "sigmoid". + bias_in_glu: bool, optional + """ + + def __init__( + self, + d_model, + d_inner, + dropout_rate, + activation="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.d_model = d_model + self.d_inner = d_inner + + self.layer_norm = nn.LayerNorm(d_model) + module = GLULinear(d_model, d_inner, activation, bias_in_glu) + self.net = nn.Sequential( + module, + nn.Dropout(dropout_rate), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout_rate), + ) + + def forward(self, x): + """FeedForward forward function. + + Args: + x: torch.Tensor + input tensor. + """ + out = self.net(self.layer_norm(x)) + + return out + + +#### positional encoding starts here +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward + compatibility. 
+ + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class T5RelativeAttentionLogitBias(nn.Module): + """ + This module implements the relative position bias described in Section + 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + + The Huggingface implementation is used as a reference + https://github.com/huggingface/transformers/blob/v4.30.0/src/ + transformers/models/t5/modeling_t5.py#L435 + + Modifies attention as Q*K^T + B, where B is a learned scalar bias based + on relative position of the query and key. It is HxNxN, where H is the + number of heads, N is the sequence length. + + I've made these modifications to the original T5 bias: + - Skipping of the bucketing step. Original T5 bias converted rel + position distances into logarithmically increasing buckets. This is + supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't + need length generalization (40s max is good enough for ASR encoder), + and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default + implementation treats L->R and R->L the same. Asymmetric was found to + yield better results in my experiments. + + Args: + num_heads: int + Number of attention heads + num_buckets: int + Number of buckets to use for relative attention bias. This is the + size of the learnable bias parameter. Bucketing is not yet + supported, so this defaults to -1 which means no bucketing is + used (max_distance determines size of bias param). + max_distance: int + Maximum distance to use for relative attention bias. With + num_buckets=-1, this directly controls the max size of the bias + parameter. When num_buckets > 0 is supported, this will control + the maximum distance for logarithmic bucketing after which all + positions are in the same bucket. + symmetric: bool + Whether to use symmetric or asymmetric biases. symmetric=False uses + 2x number of bias params to distinguish L->R from R->L. This was + found to be better for the encoder. 
+ """ + + def __init__(self, + num_heads, + num_buckets=-1, + max_distance=1000, + symmetric=False): + super().__init__() + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.symmetric = symmetric + self._skip_bucketing = self.num_buckets < 0 + if self._skip_bucketing: + self.num_buckets = max_distance + else: + raise NotImplementedError( + "T5 attention bias with bucketed positions is not yet tested") + if not self.symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + maxpos = x.size(1) + context_position = torch.arange(maxpos, + device=x.device, + dtype=torch.long)[:, None] + memory_position = torch.arange(maxpos, + device=x.device, + dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX + # export + relative_position = relative_position.masked_fill( + relative_position < -self.max_distance, -self.max_distance) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1) + + # mapping from relative position to index in the bias parameter + if self._skip_bucketing: + bias_idx = relative_position + else: + bias_idx = self._bucket_relative_position(relative_position) + if self.symmetric: + bias_idx = bias_idx.abs() + else: + bias_idx += self.num_buckets // 2 + + t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze( + 0) # [1, H, L, L] + + return t5_rel_att_bias + + def _bucket_relative_position(self, relative_position): + # this is a placeholder (isn't tested, likely buggy) using HuggingFace + # implem as a reference this also needs to be extended to support + # asymmetric +/- ve positions + relative_buckets = 0 + if not self.causal: + self.num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long) * self.num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = self.num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / + math.log(self.max_distance / max_exact) * + (self.num_buckets - max_exact)).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, self.num_buckets - 1), + ) + + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) + return relative_buckets + + +class AbsolutePositionalEncoding(nn.Module): + """Absolute Positional encoding module. + This module implement Absolute sinusoidal positional encoding + from: https://arxiv.org/pdf/1706.03762.pdf + + Args: + d_model: int + Input embedding size. 
+ dropout_rate: float + dropout rate + max_len: int, optional + Maximum input length sequence, Default 5000 + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings. + + Args: + x: torch.Tensor + """ + if self.pe is not None and self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x: torch.Tensor + Input tensor. shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, :x.size(1)] + return self.dropout(x) + + +#### forward embedding layers starts here +class MeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will subtract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.global_mean = nn.Parameter(torch.zeros(input_size)) + self.global_invstd = nn.Parameter(torch.ones(input_size)) + + def forward(self, input_: Tensor) -> Tensor: + """MeanVarianceNormLayer Forward + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + +class CausalConv1D(nn.Conv1d): + """ + A causal version of nn.Conv1d where each step would have limited access to + locations on its right or left + All arguments are the same as nn.Conv1d except padding. + + If padding is set None, then paddings are set automatically to make it a + causal convolution where each location would not see any steps on its right. + + If padding is set as a list (size of 2), then padding[0] would be used as + left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible + on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should + be equal to (kernel_size - 1). 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + self.cache_drop_size = None + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError( + "No striding allowed for non-symmetric convolutions!") + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif (isinstance(padding, list) and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1): + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + self._max_cache_len = self._left_padding + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def update_cache(self, x, cache=None): + if cache is None: + new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache + else: + new_x = F.pad(x, pad=(0, self._right_padding)) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + next_cache = new_x[:, :, :-self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1):] + return new_x, next_cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) + x = super().forward(x) + if cache is None: + return x + else: + return x, cache + + +class CausalConv2D(nn.Conv2d): + """ + A causal version of nn.Conv2d where each location in the 2D matrix would + have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be + set as None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is not None: + raise ValueError( + "Argument padding should be set to None for CausalConv2D.") + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + + padding = 0 + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + def forward( + self, + x, + ): + x = F.pad( + x, + pad=(self._left_padding, self._right_padding, 0, 0), + ) + x = super().forward(x) + return x + + +class NemoConvSubsampling(torch.nn.Module): + """Convlutional subsampling module, taken from NeMo ASR + (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a + 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) + + Striding Subsampling: "Speech-Transformer: A No-Recurrence + Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong + et al. (https://ieeexplore.ieee.org/document/8462506) + + + Compared with the EncoderConv2D (`input_layer: custom`), this is a + much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. + Moreover, depthwise convolutions are used to reduce FLOPs, but the first + layer is kept as a regular convolution so as not to degrade accuracy. 
+ + `Striding` and `dw_striding` are the same except that the latter uses + depthwise convolutions after the first layer, whereas the former does not. + + Args: + subsampling_factor (int): Time reduction factor + feat_in (int): size of the input features + feat_out (int): size of the output features + subsampling (str): The subsampling technique, choose from + {"striding", "dw-striding", "striding_conv1d", + "dw_striding_conv1d"} + conv_channels (int): Number of channels for the convolution layers, + default is 256. + subsampling_conv_chunking_factor (int): Input chunking factor which + can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 + activation (Module): activation function, default is nn.ReLU() + is_causal (bool): whether to use causal Conv1/2D, where each step will + have limited access to locations on its right or left + """ + + def __init__( + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), # noqa: B008 + is_causal=False, + ): + super().__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + self.subsampling_factor = subsampling_factor + self.is_causal = is_causal + self.subsampling_causal_cond = subsampling in ( + "dw_striding", + "striding", + "striding_conv1d", + ) + + if (subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a "\ + "power of 2" + ) + self.subsampling_conv_chunking_factor = \ + subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + if subsampling == "dw_striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + # Layer 1 + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + )) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + )) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + groups=in_channels, + )) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + )) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + )) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = 
False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + )) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + )) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv1D( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == i + + 1 else conv_channels), + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + )) + else: + layers.append( + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == i + + 1 else conv_channels), + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + )) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "dw_striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + layers.extend([ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == 1 else + conv_channels), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ]) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend([ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == i + + 2 else conv_channels), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ]) + layers.append(activation) + in_channels = conv_channels + + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + if subsampling in ["dw_striding", "striding"]: + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) + self.out = torch.nn.Linear(conv_channels * int(out_length), + feat_out) + self.conv2d_subsampling = True + elif subsampling in ["striding_conv1d", 
"dw_striding_conv1d"]: + self.out = None + self.conv2d_subsampling = False + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_factor] + + def get_streaming_cache_size(self): + return [0, self.subsampling_factor + 1] + + def forward(self, x, mask): + """ + Forward method for NeMo subsampling. + + Args: + x[Batch, Time, Filters]: torch.Tensor + input tensor + x_mask: torch.Tensor + input mask + + Returns: + x: torch.Tensor + Resulting tensor from subsampling (B, T // + time_reduction_factor, feat_out) + pad_mask: torch.Tensor + tensor of padded hidden state sequences (B, 1, T // + time_reduction_factor) + """ + x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) + + # split inputs if chunking_factor is set + if (self.subsampling_conv_chunking_factor != -1 + and self.conv2d_subsampling): + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only + # if needed. + # avoiding a bug / feature limiting indexing of tensors + # to 2**31. + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = (2**31 / self._conv_channels * self._stride * + self._stride) + need_to_split = torch.numel(x) > x_ceil + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + if need_to_split: + x, success = self.conv_split_by_batch(x) + if not success: # if unable to split by batch, try by channel + if self._subsampling == "dw_striding": + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) # try anyway + else: + x = self.conv(x) + else: + x = self.conv(x) + + # Flatten Channel and Frequency Axes + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + if mask is None: + return x, None + + max_audio_length = x.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + if self.is_causal and self.subsampling_causal_cond: + feature_lens_remainder = feature_lens % self.subsampling_factor + padding_length[feature_lens_remainder != 1] += 1 + pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( + padding_length.size(0), -1) < padding_length.unsqueeze(1) + return x, pad_mask.unsqueeze(1) + + def reset_parameters(self): + # initialize weights + if self._subsampling == "dw_striding": + with torch.no_grad(): + # init conv + scale = 1.0 / self._kernel_size + dw_max = (self._kernel_size**2)**-0.5 + pw_max = self._conv_channels**-0.5 + + torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) + torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) + + for idx in range(2, len(self.conv), 3): + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, + dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, + dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, + pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, + pw_max) + + # init fc (80 * 64 = 5120 from https://github.com/kssteven418/ + # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/ + # src/models/conformer_encoder.py#L487 + fc_scale = (self._feat_out * self._feat_in / + self._sampling_num)**-0.5 + torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) + torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) + + def conv_split_by_batch(self, x): + """Tries to split input by batch, run 
conv and concat results""" + b, _, _, _ = x.size() + if b == 1: # can't split if batch size is 1 + return x, False + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) + cf = 2**p + + new_batch_size = b // cf + if new_batch_size == 0: # input is too big + return x, False + + return ( + torch.cat([ + self.conv(chunk) + for chunk in torch.split(x, new_batch_size, 0) + ]), + True, + ) + + def conv_split_by_channel(self, x): + """For dw convs, tries to split input by time, run conv and concat + results""" + x = self.conv[0](x) # full conv2D + x = self.conv[1](x) # activation + + for i in range(self._sampling_num - 1): + _, c, t, _ = x.size() + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors + # to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p + + new_c = int(c // cf) + if new_c == 0: + new_c = 1 + + new_t = int(t // cf) + if new_t == 0: + new_t = 1 + + x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, + x) # conv2D, depthwise + + # splitting pointwise convs by time + x = torch.cat( + [ + self.conv[i * 3 + 3](chunk) + for chunk in torch.split(x, new_t, 2) + ], + 2, + ) # conv2D, pointwise + x = self.conv[i * 3 + 4](x) # activation + return x + + def channel_chunked_conv(self, conv, chunk_size, x): + """Performs channel chunked convolution""" + + ind = 0 + out_chunks = [] + for chunk in torch.split(x, chunk_size, 1): + step = chunk.size()[1] + + if self.is_causal: + chunk = nn.functional.pad( + chunk, + pad=( + self._kernel_size - 1, + self._stride - 1, + self._kernel_size - 1, + self._stride - 1, + ), + ) + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind:ind + step, :, :, :], + bias=conv.bias[ind:ind + step], + stride=self._stride, + padding=0, + groups=step, + ) + else: + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind:ind + step, :, :, :], + bias=conv.bias[ind:ind + step], + stride=self._stride, + padding=self._left_padding, + groups=step, + ) + out_chunks.append(ch_out) + ind += step + + return torch.cat(out_chunks, 1) + + def change_subsampling_conv_chunking_factor( + self, subsampling_conv_chunking_factor: int): + if (subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a "\ + "power of 2" + ) + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + +def calc_length(lengths, + all_paddings, + kernel_size, + stride, + ceil_mode, + repeat_num=1): + """Calculates the output length of a Tensor passed through a convolution or + max pooling layer""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + + one) + lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) + return lengths.to(dtype=torch.int) + + +#### multihead attention starts here +class AttModule(nn.Module): + """Attention abstraction module""" + + def __init__(self): + super().__init__() + self.export_mode = False + + def 
set_export(self, mode=True): + """set the export mode""" + self.export_mode = mode + + def forward( + self, + x: Tensor, + memory: Optional[Tensor] = None, + pos_emb: Optional[Tensor] = None, + att_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """AttModule forward + + Args: + x: torch.Tensor + input tensor. + memory: torch.Tensor, optional + memory tensor. + pos_emb: torch.Tensor, optional + positional encoder embedding. + att_mask: torch.Tensor, optional + attention mask tensor. + """ + return x, memory, pos_emb, att_mask + + +class AttBlock(Block, AttModule): + """Attention Block module to support both Attention and Block module.""" + + def memory_dims(self, max_len=False): + """memory dimensions""" + return (1, self.input_size) + + +def masked_softmax( + scores, + mask: Optional[Tensor], +): + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -torch.inf) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + return attn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with optional relative position embedding + and GLU. + + Args: + n_head: int + the number of heads. + n_feat: int + input size features. + dropout_rate: float + dropout rate. + use_LN: bool + apply layer norm or not + dropout_at_output: bool + whether to apply dropout at output + attention_inner_dim: int, optional + the attention dimension used in the class, + it can be different from the input dimension n_feat. + default: -1 (equal to n_feat). + use_pt_scaled_dot_product_attention: bool, optional + if set True, use pytorch scaled dot product attention in training. + NOTE: this will NOT be used in ONNX decoding due to a lack of + support. In that case, we use the original attention + implementation, which shows no regression. + default: False. + n_value: int, optional + if set to values other than -1, use a different dimension for + value. With the default value (i.e. -1), it is backward compatible. + group_size: int, optional. 
must divide `n_head` + if group_size > 1: GQA + if group_size = 1: MHA + if group_size = n_head: MQA + """ + + inv_sqrt_d_k: torch.jit.Final[float] + h: torch.jit.Final[int] + h_k: torch.jit.Final[int] + g: torch.jit.Final[int] + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + attention_inner_dim=-1, + glu_type="swish", + bias_in_glu=True, + use_pt_scaled_dot_product_attention=False, + n_value=-1, + group_size: int = 1, + ): + super().__init__() + if n_value == -1: + n_value = n_feat + if attention_inner_dim == -1: + attention_inner_dim = n_feat + assert attention_inner_dim % n_head == 0 + + # We assume d_v always equals d_k + self.d_k = attention_inner_dim // n_head + self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) + self.h = n_head + assert n_head % group_size == 0, "group_size must divide n_head" + self.g = group_size + self.h_k = n_head // group_size + + self.linear_q = nn.Linear(n_feat, attention_inner_dim) + self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) + self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) + self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) + + self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.dropout = nn.Dropout(p=dropout_rate) + self.dropout_rate = dropout_rate + self.use_pt_scaled_dot_product_attention = ( + use_pt_scaled_dot_product_attention) + + if use_pt_scaled_dot_product_attention and group_size > 1: + raise ValueError("Cannot use PT Scaled Attention with GQA") + + # Torchscript eager quantization. Note that these functions below are + # NOOPs and have very little impact on performance unless quantization + # is enabled. + self.quant_q = torch.ao.quantization.QuantStub() + self.quant_x = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.ffunc = torch.ao.nn.quantized.FloatFunctional() + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_k: Tensor, + pos_v: Tensor, + mask: Optional[Tensor], + relative_attention_bias: Optional[Tensor] = None, + ): + """Compute 'Scaled Dot Product Attention'. + + Args: + query: torch.Tensor + query tensor (batch, time1, size) + key: torch.Tensor + key tensor (batch, time2, size) + value: torch.Tensor + value tensor (batch, time1, size) + pos_k: torch.Tensor + key tensor used for relative positional embedding. + pos_v: torch.Tensor + value tensor used for relative positional embedding. + mask: torch.Tensor + mask tensor (batch, time1, time2) + relative_attention_bias: torch.Tensor + bias added to attention logits w.r.t. 
relative positions + (1, n_head, time1, time2) + """ + n_batch = query.size(0) + + q = self.linear_q(query).view(n_batch, -1, self.h, + self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, + self.d_k) # (b, t, d) + v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) + q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting() else q.transpose(1, 2) * + self.inv_sqrt_d_k) + k = k.transpose(1, 2) # (batch, head_k, time2, d_k) + v = v.transpose(1, 2) # (batch, head_k, time2, d_k) + + if (self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting()): + attn_mask = None + if mask is not None: + mask = mask.unsqueeze(1) + if relative_attention_bias is not None: + attn_mask = mask + relative_attention_bias + else: + attn_mask = mask + if mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + with torch.nn.attention.sdpa_kernel([ + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + torch.nn.attention.SDPBackend.MATH, + torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + ]): + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_rate, + ) + else: + if self.h != self.h_k: + q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) + A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) + else: + A = torch.matmul(q, k.transpose(-2, -1)) + if pos_k is not None: + if self.h != self.h_k: + B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) + else: + reshape_q = (q.contiguous().view(n_batch * self.h, -1, + self.d_k).transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul(reshape_q, + pos_k.transpose(-2, + -1)) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), + pos_k.size(1)) + scores = A + B + else: + scores = A + + if relative_attention_bias is not None: + scores = scores + relative_attention_bias + + attn = masked_softmax(scores, mask) # (batch, head, time1, time2) + + self.attn = attn + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn.to(v.dtype), + v) # (batch, head, time1, d_k) + if pos_v is not None: + reshape_attn = (p_attn.contiguous().view( + n_batch * self.h, pos_v.size(0), + pos_v.size(1)).transpose(0, 1)) # (t1, bh, t2) + + attn_v = (torch.matmul(reshape_attn, pos_v).transpose( + 0, 1).contiguous().view(n_batch, self.h, pos_v.size(0), + self.d_k)) + x = x + attn_v + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h_k * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential""" + + @torch.jit.ignore + def forward(self, *args): + """Forward method implementation.""" + for m in self: + args = m(*args) + return args + + +def get_offset(input_layer: str, time_reduction: int): + """Get an offset. We will use the offset for determining #frames of a + subsampled feature. 
+ + Args: + input_layer (str): Type of an input layer + time_reduction (int): time reduction factor for downsampling a feature + Returns: + int: offset + """ + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: + return 3 + if input_layer in ("conv2d", ) and time_reduction == 6: + return 1 + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: + return 7 + return 0 + + +def unfold_tensor(xs_pad, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is + longer than max_seq_len, this function unfold it to a + (NT', max_seq_len, D) where T' is T // max_seq_len. + Args: + xs_pad: N, T, D + """ + _, _, D = xs_pad.shape + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + # N x D x 1 x T => N x (D x max_seq_len) x T' + xs_pad = F.unfold( + xs_pad[..., None, :], + kernel_size=(1, max_seq_len), + stride=(1, max_seq_len), + ) + new_bsz, _, slen = xs_pad.shape + # N x D x max_seq_len x T' + xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) + # N x T' x max_seq_len x D + xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() + # NT' x max_seq_len x D + xs_pad = xs_pad.view(-1, max_seq_len, D) + return xs_pad diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py new file mode 100644 index 0000000..4a4420c --- /dev/null +++ b/vllm/model_executor/models/phimoe.py @@ -0,0 +1,671 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only PhiMoE model.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class PhiMoEConfig(PretrainedConfig): + + model_type = "phimoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + attention_bias=False, + lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class mp(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + scores: torch.Tensor, + multiplier: torch.Tensor, + selected_experts: torch.Tensor, + masked_gates: torch.Tensor, 
+ mask_for_one: torch.Tensor, + ): + ctx.save_for_backward(multiplier, selected_experts, masked_gates) + return multiplier * mask_for_one + + @staticmethod + def backward( + ctx, + grad_at_output: torch.Tensor, + ): + multiplier, selected_experts, masked_gates = ctx.saved_tensors + + grad_at_output = grad_at_output * multiplier + + grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1) + grad_at_scores_expaned.scatter_add_( + dim=-1, + index=selected_experts, + src=grad_at_output, + ) + + return ( + grad_at_scores_expaned, + None, + None, + None, + None, + ) + + +def sparsemixer(scores, jitter_eps=0.01): + ################ first expert ################ + + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ((mask_logits_threshold - scores) / + factor) > (2 * jitter_eps) + + # apply mask + masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) + selected_experts = max_ind + + # compute scores for gradients + masked_gates = torch.softmax(masked_gates, dim=-1) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + + multiplier = multiplier_o + + # masked out first expert + masked_scores = torch.scatter( + scores, + -1, + selected_experts, + float("-inf"), + ) + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, + keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ((mask_logits_threshold - scores) / + factor) > (2 * jitter_eps) + + # apply mask + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, + float("-inf")) + selected_experts_top2 = max_ind + # compute scores for gradients + masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) + multiplier_top2 = masked_gates_top2.gather(dim=-1, + index=selected_experts_top2) + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), + dim=-1) + multiplier = multiplier.to(torch.float32) + selected_experts = selected_experts.to(torch.int32) + + return ( + multiplier, + selected_experts, + ) + + +def phimoe_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert topk == 2, "Only top-2 routing is supported" + assert renormalize is False, "Renormalization is not supported" + + topk_weights, topk_ids = sparsemixer(gating_output) + return topk_weights, topk_ids + + +class PhiMoE(nn.Module): + """A tensor-parallel MoE implementation for PhiMoE that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + custom_routing_function=phimoe_routing_function, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class PhiMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[dict] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class PhiMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = PhiMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=config.rope_scaling, + prefix=f"{prefix}.self_attn", + ) + self.block_sparse_moe = PhiMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe", + ) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + + # Self Attention + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states, residual + + +@support_torch_compile +class PhiMoEModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = 
vllm_config.lora_config + + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiMoEDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + self.quant_config = vllm_config.quant_config + + self.model = PhiMoEModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=None, + bias=True, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: 
Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py new file mode 100644 index 0000000..a6e4ea6 --- /dev/null +++ b/vllm/model_executor/models/pixtral.py @@ -0,0 +1,1353 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass, fields +from functools import cached_property +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mistral_common.protocol.instruct.messages import ImageChunk +from mistral_common.tokens.tokenizers.multimodal import ImageEncoder +from PIL import Image +from transformers import PixtralVisionConfig, TensorType +from transformers.image_utils import ImageInput +from transformers.models.pixtral.image_processing_pixtral import ( + _num_image_tokens as _get_pixtral_hf_num_image_tokens) +from transformers.models.pixtral.modeling_pixtral import ( + PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) +from transformers.tokenization_utils_base import TextInput + +from vllm.config import VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import (MistralTokenizer, + cached_tokenizer_from_config) + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) +from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs, + scatter_patch_features, select_patch_features) + +import ixformer.inference.functions as ixf +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + +PATCH_MERGE = "patch_merge" + + +class PixtralImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + + images: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_channels, image_width, image_height)` + + The result of stacking :attr:`ImageEncoding.tokens` from each prompt. 
+ """ + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_images, num_embeds)` + """ + + +class PixtralProcessorAdapter: + """ + Provide a HF-compatible interface for + :class:`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`. + """ + + def __init__(self, tokenizer: MistralTokenizer) -> None: + super().__init__() + + self.tokenizer = tokenizer + + @property + def image_processor(self) -> ImageEncoder: + image_encoder = self.tokenizer.instruct.mm_encoder + assert isinstance(image_encoder, ImageEncoder) + return image_encoder + + @cached_property + def image_break_id(self) -> int: + return self.image_processor.special_ids.img_break + + @cached_property + def image_token_id(self) -> int: + return self.image_processor.special_ids.img + + @cached_property + def image_end_id(self) -> int: + return self.image_processor.special_ids.img_end + + @cached_property + def image_size(self) -> int: + return self.image_processor.mm_config.max_image_size + + @cached_property + def patch_size(self) -> int: + return self.image_processor.mm_config.image_patch_size + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> Mapping[str, NestedTensors]: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + if not images: + input_ids = self.tokenizer(text).input_ids + + return {"input_ids": torch.tensor(input_ids)} + + # Allow dummy text, which is used for profiling as well as token inputs + if any(len(t) > 0 for t in text): + raise ValueError( + "You've passed text inputs instead of token inputs. " + "Make sure to process your input via `mistral_common`'s " + "tokenizer or pass a chat completion request. 
" + "For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.") + + image_token_id = self.image_token_id + + images_processed = list[torch.Tensor]() + images_tokens = list[torch.Tensor]() + images_embed_is_patch = list[torch.Tensor]() + + for image in images: + image_inputs = self.image_processor(ImageChunk(image=image)) + image_processed = torch.tensor(image_inputs.image) + image_tokens = torch.tensor(image_inputs.tokens) + + images_processed.append(image_processed) + images_tokens.append(image_tokens) + images_embed_is_patch.append(image_tokens == image_token_id) + + return { + "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), + "images": images_processed, + "embed_is_patch": images_embed_is_patch, + } + + +class PixtralProcessingInfo(BaseProcessingInfo): + + def get_tokenizer(self) -> MistralTokenizer: + tokenizer = cached_tokenizer_from_config(self.ctx.model_config) + if not isinstance(tokenizer, MistralTokenizer): + raise ValueError("This model requires `--tokenizer-mode mistral`") + + return tokenizer + + def get_hf_processor(self) -> PixtralProcessorAdapter: + return PixtralProcessorAdapter(self.get_tokenizer()) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_vision_config( + self, + processor: Optional[PixtralProcessorAdapter] = None, + ): + if processor is None: + processor = self.get_hf_processor() + + return PixtralVisionConfig( + image_size=processor.image_size, + patch_size=processor.patch_size, + ) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[PixtralProcessorAdapter] = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + ncols, nrows = processor.image_processor._image_to_num_tokens( + Image.new("RGB", (image_width, image_height))) + + return (ncols + 1) * nrows + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_hf_processor().image_processor + max_image_size = image_processor.mm_config.max_image_size + + return ImageSize(width=max_image_size, height=max_image_size) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] + ): + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + images=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + 
out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_break_id = processor.image_break_id + image_token_id = processor.image_token_id + image_end_id = processor.image_end_id + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = processor.image_processor._image_to_num_tokens( + Image.new("RGB", (image_size.width, image_size.height))) + + tokens = ([image_token_id] * ncols + [image_break_id]) * nrows + tokens[-1] = image_end_id + + return tokens + + return [ + PromptReplacement( + modality="image", + target="", # Never match the prompt (see below note) + replacement=get_replacement, + ), + ] + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs, bool]: + prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + # NOTE: The tokens are already inserted by the chat template + return prompt_ids, mm_kwargs, True + + +@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, + info=PixtralProcessingInfo, + dummy_inputs=PixtralDummyInputsBuilder) +class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} + vision_args = { + key: value + for key, value in self.config.vision_config.to_dict().items() + if key in dataclass_fields + } + + self.vision_args = VisionEncoderArgs(**vision_args) + + # init MistralForCausalLM + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.vision_encoder = VisionTransformer(self.vision_args) + + if self.vision_args.add_pre_mm_projector_layer_norm: + self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, + eps=1e-5) + + if self.vision_args.mm_projector_id == PATCH_MERGE: + self.patch_merger = PatchMerger( + vision_encoder_dim=self.vision_args.hidden_size, + spatial_merge_size=self.vision_args.spatial_merge_size, + use_mlp_bias=False, + ) + + self.vision_language_adapter = VisionLanguageAdapter( + self.vision_args, dim=config.text_config.hidden_size) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: + images = kwargs.pop("images", None) + if images is None: + return None + + if not isinstance(images, (torch.Tensor, list)): + raise ValueError("Incorrect type of images. " + f"Got type: {type(images)}") + + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. 
" + f"Got type: {type(embed_is_patch)}") + + embed_is_patch = flatten_bn(embed_is_patch) + + return PixtralImagePixelInputs( + type="pixel_values", + images=flatten_bn(images), + embed_is_patch=embed_is_patch, + ) + + def _process_image_input( + self, + image_input: PixtralImagePixelInputs, + ) -> tuple[torch.Tensor, ...]: + images = image_input["images"] + image_features = self.vision_encoder(images) + feature_sizes = [ + image_feature.shape[0] for image_feature in image_features + ] + image_features = torch.cat(image_features) + if self.vision_args.add_pre_mm_projector_layer_norm: + image_features = self.pre_mm_projector_norm(image_features) + if self.vision_args.mm_projector_id == PATCH_MERGE: + patch_size = self.vision_args.patch_size + spatial_merge_size_square = self.vision_args.spatial_merge_size**2 + img_patch_dims = [(img.shape[1] // patch_size, + img.shape[2] // patch_size) for img in images] + feature_sizes = [ + feature_size // spatial_merge_size_square + for feature_size in feature_sizes + ] + image_features = self.patch_merger(image_features, + image_sizes=img_patch_dims) + image_embeds = self.vision_language_adapter(image_features) + image_embeds = torch.split(image_embeds, feature_sizes) + return image_embeds + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.vision_args.image_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for pixtral.""" + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("vision_encoder") + + def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("vision_language_adapter") + + def is_patch_merger(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("patch_merger") + + def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("pre_mm_projector_norm") + + # Get references to parameters for direct loading + vision_encoder_dict = dict(self.vision_encoder.named_parameters()) + patch_merger_dict = dict(self.patch_merger.named_parameters( + )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict() + pre_mm_projector_norm_dict = dict( + self.pre_mm_projector_norm.named_parameters( + )) if self.vision_args.add_pre_mm_projector_layer_norm else dict() + vision_lang_adapter_dict = dict( + self.vision_language_adapter.named_parameters()) + + def llm_weights_generator(): + # Single pass over weights + for name, w in weights: + if is_vision_encoder_weights((name, w)): + # Load vision encoder weights directly + trimmed_name = '.'.join(name.split(".")[1:]) + param = vision_encoder_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_patch_merger((name, w)): + # Load vision patch merger weights directly + trimmed_name = '.'.join(name.split(".")[1:]) + param = patch_merger_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_pre_mm_projector_norm((name, w)): + # Load vision pre_mm_projector_norm weights directly + trimmed_name = '.'.join(name.split(".")[1:]) + param = pre_mm_projector_norm_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_vision_lang_adapter_weights((name, w)): + # Load vision-language adapter weights directly + trimmed_name = '.'.join(name.split(".")[1:]) + param = vision_lang_adapter_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + else: + # LLM weights: yield them to be loaded + # by language_model.load_weights + yield (name, w) + + # Now we call the language model load with the generator + self.language_model.load_weights(llm_weights_generator()) + + +# Vision encoder +@dataclass +class VisionEncoderArgs: + hidden_size: int + num_channels: int + image_size: int + patch_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + rope_theta: float # for rope-2D + image_token_id: int + adapter_bias: bool = True + spatial_merge_size: int = 1 + add_pre_mm_projector_layer_norm: bool = False + mm_projector_id: str = "" + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, + x: torch.Tensor) -> torch.Tensor: + 
""" + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert ndim > 1 + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [ + d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) + ] + return freqs_cis.view(*shape) + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) + to be indexed by (height, width) position tuples + """ + # (dim / 2) frequency bases + freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def apply_rotary_emb_vit( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + assert freqs_cis.dtype == torch.complex64 + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class FeedForward(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + assert args.intermediate_size is not None + self.w1 = nn.Linear(args.hidden_size, + args.intermediate_size, + bias=False) + self.w2 = nn.Linear(args.intermediate_size, + args.hidden_size, + bias=False) + self.w3 = nn.Linear(args.hidden_size, + args.intermediate_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Attention(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + assert not args.hidden_size % args.num_attention_heads + self.n_heads = args.num_attention_heads + self.head_dim = args.hidden_size // args.num_attention_heads + + self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + batch, patches, _ = x.shape + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.reshape(batch, patches, self.n_heads, self.head_dim) + k = k.reshape(batch, patches, self.n_heads, self.head_dim) + v = v.reshape(batch * patches, self.n_heads, self.head_dim) + + q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) + q = q.view(batch * patches, self.n_heads, self.head_dim) + k = k.view(batch * patches, self.n_heads, self.head_dim) + out = ixf.ixinfer_flash_attn_unpad(q,k,v, mask.q_seqinfo.seqstart.to(q.device), mask.k_seqinfo.seqstart.to(q.device), mask.q_seqinfo.max_seqlen, mask.k_seqinfo.max_seqlen) + # out = memory_efficient_attention(q, k, v, attn_bias=mask) + out = out.reshape(batch, patches, self.n_heads * self.head_dim) + return 
self.wo(out) + + +class TransformerBlock(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.attention = Attention(args) + self.feed_forward = FeedForward(args) + self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) + self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), + mask=mask, + freqs_cis=freqs_cis) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class Transformer(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(args.num_hidden_layers): + self.layers.append(TransformerBlock(args)) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: Optional[torch.Tensor], + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + +def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor: + positions = torch.cat([ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) for p in patch_embeds_list + ]) + return positions + + +class VisionTransformer(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + self.patch_conv = nn.Conv2d( + in_channels=args.num_channels, + out_channels=args.hidden_size, + kernel_size=args.patch_size, + stride=args.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) + self.transformer = Transformer(args) + + head_dim = self.args.hidden_size // self.args.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: Optional[torch.Tensor] = None + + @property + def max_patches_per_side(self) -> int: + return self.args.image_size // self.args.patch_size + + @property + def device(self) -> torch.types.Device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.args.hidden_size // self.args.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.args.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + images: List[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + images: list of N_img images of variable sizes, + each of shape (C, H, W) + Returns: + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds_list = [ + self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images + ] + + patch_embeds = [ + p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list + ] + embed_sizes = [p.shape[1] for p in patch_embeds] + + # flatten to a single sequence + patch_embeds = torch.cat(patch_embeds, dim=1) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask 
delimiting images + if USE_XFORMERS_OPS: + mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + else: + raise ImportError("Xformers is required for Pixtral inference " + "with the Mistral format") + out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + # squeeze dim 0 and split into separate tensors for each image + return torch.split(out.squeeze(0), embed_sizes) + + +class VisionLanguageAdapter(nn.Module): + + def __init__(self, args: VisionEncoderArgs, dim: int): + super().__init__() + assert isinstance(args, VisionEncoderArgs) + self.w_in = nn.Linear( + args.hidden_size, + dim, + bias=args.adapter_bias, + ) + self.gelu = nn.GELU() + self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w_out(self.gelu(self.w_in(x))) + + +class PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__( + self, + vision_encoder_dim: int, + spatial_merge_size: int, + use_mlp_bias: bool = False, + ) -> None: + super().__init__() + + mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) + + self.spatial_merge_size = spatial_merge_size + self.mlp_input_dim = mlp_input_dim + + self.merging_layer = nn.Linear( + mlp_input_dim, + vision_encoder_dim, + bias=use_mlp_bias, + ) + + def forward(self, x: torch.Tensor, + image_sizes: list[tuple[int, int]]) -> torch.Tensor: + # image_sizes specified in tokens + assert sum([h * w for h, w in image_sizes]) == len(x) + + # x is (N, vision_encoder_dim) + x = self.permute(x, image_sizes) + + # x is (N / spatial_merge_size ** 2, + # vision_encoder_dim * spatial_merge_size ** 2) + x = self.merging_layer(x) + + # x is (N / spatial_merge_size ** 2, vision_encoder_dim) + return x + + def permute( + self, + x: torch.Tensor, + image_sizes: list[tuple[int, int]], + ) -> torch.Tensor: + """ + Args: + x: (N, D) where N is flattened and concatenated patch tokens + for all images + image_sizes: list of tuple of (height, width) in tokens for + each image + Returns: + image_features: reorders patch tokens so each grid of + (spatial_merge_size, spatial_merge_size) is contiguous. 
+ now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) + """ + + sub_grids = get_sub_grids( + x=x, + image_sizes=image_sizes, + spatial_merge_size=self.spatial_merge_size + ) # list of [d x sub_grid_size x sub_grid_size x n_patches] + permuted_tensor: list[torch.Tensor] = [] + for grid in sub_grids: + n_patches = grid.shape[-1] + permuted_tensor.append(grid.view(-1, n_patches).t( + )) # n_patches x d * sub_grid_size * sub_grid_size + return torch.cat( + permuted_tensor, dim=0 + ) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) + + +def get_sub_grids( + x: torch.Tensor, + image_sizes: list[tuple[int, int]], + spatial_merge_size: int, +) -> list[torch.Tensor]: + # image_sizes specified in tokens + tokens_per_image = [h * w for h, w in image_sizes] + d = x.shape[-1] + all_img_sub_grids: list[torch.Tensor] = [] + sub_grid_size = spatial_merge_size + + for image_index, image_tokens in enumerate(x.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute( + 2, 0, 1)[None, :, :, :] # 1 x d x h x w + sub_grids = torch.nn.functional.unfold(image_grid, + kernel_size=sub_grid_size, + stride=sub_grid_size) + sub_grids = sub_grids.view( + 1, d, sub_grid_size, sub_grid_size, + -1) # 1 x d x sub_grid_size x sub_grid_size x n_patches + + all_img_sub_grids.append(sub_grids[0]) + + return all_img_sub_grids + + +#### HF Transformers version of Pixtral #### +# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py +# This model follows the Llava family, meaning image embeddings are placed +# instead of the `[IMG]` token placeholders. +# The model uses [`PixtralVisionModel`] for its vision encoder, +# and [`MistralForCausalLM`] for its language decoder. 
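+#
+# Rough token-count sketch (illustrative numbers only, assuming the defaults
+# image_size=1024, patch_size=16 and spatial_merge_size=1): a 512x512 image is
+# not downscaled (ratio = 512/1024 <= 1) and maps to a 32x32 patch grid, so
+# get_num_image_tokens below returns (ncols + 1) * nrows = (32 + 1) * 32 =
+# 1056 image tokens, where the "+ 1" per row accounts for the image-break
+# token appended after each row of patches.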
+ + +class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + ncols, nrows = self.get_patch_grid_size( + image_width=image_width, + image_height=image_height, + ) + + # Consider the image_break_token + return (ncols + 1) * nrows + + def get_max_image_tokens(self) -> int: + image_size = self.get_image_size() + + return self.get_num_image_tokens( + image_width=image_size, + image_height=image_size, + ) + + def get_image_size(self) -> int: + return self.vision_config.image_size + + def get_patch_size(self) -> int: + return (self.vision_config.patch_size * + self.vision_config.spatial_merge_size) + + def get_patch_grid_length(self) -> int: + image_size, patch_size = self.get_image_size(), self.get_patch_size() + + # Since interpolation is applied, the image size need not be divisible + # assert image_size % patch_size == 0 + return image_size // patch_size + + # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99 + def get_patch_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + max_width = max_height = self.get_image_size() + patch_width = patch_height = self.get_patch_size() + + ratio = max(image_width / max_width, image_height / max_height) + + if ratio > 1: + image_width = int(math.floor(image_width / ratio)) + image_height = int(math.floor(image_height / ratio)) + + nrows, ncols = _get_pixtral_hf_num_image_tokens( + (image_height, image_width), + (patch_height, patch_width), + ) # type: ignore + + return ncols, nrows + + +class PixtralHFMLP(nn.Module): + + def __init__( + self, + config: PixtralVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + assert config.intermediate_size is not None + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_and_mul = get_act_and_mul_fn(config.hidden_act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_and_mul(gate_up) + x, _ = self.down_proj(x) + return x + + +class PixtralHFAttention(nn.Module): + + def __init__( + self, + config: PixtralVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + assert not config.hidden_size % config.num_attention_heads + self.total_num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + self.n_heads = divide(config.num_attention_heads, tp_size) + self.head_dim = config.hidden_size // config.num_attention_heads + + self.qkv_proj = QKVParallelLinear( + hidden_size=config.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + assert self.total_num_heads * self.head_dim == config.hidden_size + self.o_proj = RowParallelLinear( + input_size=config.hidden_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + def forward( + 
self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + batch, patches, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + q, k, v = qkv_states.chunk(3, dim=-1) + + # Transpose q and k to apply HF's Rotary Position Embedding + q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) + k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) + v = v.view(batch, patches, self.n_heads, self.head_dim) + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0) + + if USE_XFORMERS_OPS: + # Transpose q and k back for attention + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + + out = xops.memory_efficient_attention(q, + k, + v, + attn_bias=attention_mask) + else: + v = v.transpose(1, 2) + out = nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask) + out = out.transpose(1, 2) + + out = out.view(batch, patches, self.n_heads * self.head_dim) + attn_output, _ = self.o_proj(out) + + return attn_output, None + + +class PixtralHFTransformerBlock(nn.Module): + + def __init__( + self, + config: PixtralVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) + self.attention = PixtralHFAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.feed_forward = PixtralHFMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward") + self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> torch.Tensor: + r, _ = self.attention.forward(self.attention_norm(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings) + h = hidden_states + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class PixtralHFTransformer(nn.Module): + + def __init__( + self, + config: PixtralVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + PixtralHFTransformerBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward( + self, + x: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + return_all_hidden_states: bool, + ) -> torch.Tensor: + hidden_states_pool = [x] + + for layer in self.layers: + x = layer(x, attention_mask, position_embeddings) + if return_all_hidden_states: + hidden_states_pool.append(x) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. 
+ if return_all_hidden_states: + return hidden_states_pool + return x + + +class PixtralHFVisionModel(nn.Module): + + def __init__( + self, + config: PixtralVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) + self.transformer = PixtralHFTransformer( + config, + quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.transformer", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.transformer.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.transformer.layers)} " + "layers.") + + if require_post_norm is True: + msg = "PixtralHFVisionModel does not have post-layernorm" + raise ValueError(msg) + + self.dtype = next(self.parameters()).dtype + self.device = next(self.parameters()).device + self.patch_positional_embedding = PixtralRotaryEmbedding( + config, self.device) + + def forward( + self, + pixel_values: list[torch.Tensor], + feature_sample_layers: Optional[list[int]] = None, + ) -> tuple[torch.Tensor, ...]: + """ + Args: + pixel_values: Each image to be processed will be a separate tensor + in pixel_values. This means it will be a list of tensors + because multiple requests batched can have multiple images, + each with their own shape potentially + feature_sample_layers: Layer indices whose features should be + concatenated and used as the visual encoder output. If none + are provided, the last layer is used. 
+ + Returns: + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds_list = [ + self.patch_conv(img.unsqueeze(0).to(self.dtype)) + for img in pixel_values + ] + + patch_embeds = [ + p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list + ] + embed_sizes = [p.shape[1] for p in patch_embeds] + + # flatten to a single sequence + patch_embeds = torch.cat(patch_embeds, dim=1) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + patch_embeds_list, + max_width=self.config.image_size // self.config.patch_size).to( + self.device) + position_embedding = self.patch_positional_embedding( + patch_embeds, position_ids) + + if USE_XFORMERS_OPS: + attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + else: + from transformers.models.pixtral.modeling_pixtral import ( + generate_block_attention_mask) + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + patch_embeds) + + return_all_hidden_states = feature_sample_layers is not None + out = self.transformer( + patch_embeds, + attention_mask, + position_embedding, + return_all_hidden_states=return_all_hidden_states) + + out = resolve_visual_encoder_outputs(out, feature_sample_layers, None, + self.config.num_hidden_layers) + + # squeeze dim 0 and split into separate tensors for each image + return torch.split(out.squeeze(0), embed_sizes) + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + layer_count = len(self.transformer.layers) + + for name, loaded_weight in weights: + # omit layers when num_hidden_layers_override is set + if name.startswith("transformer.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py new file mode 100644 index 0000000..a69c0fc --- /dev/null +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The vLLM team. +# Copyright 2025 IBM. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM/NASA Prithvi Geospatial model.""" +from collections.abc import Iterable, Mapping, Sequence +from typing import Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (IsAttentionFree, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import (IntermediateTensors, PoolerOutput, + PoolingSequenceGroupOutput) + + +class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return {"image": 0} + + +class PrithviGeoSpatialMAEInputBuilder( + BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + return ProcessorInputs( + prompt_text="", + # This model input is fixed and is in the form of a torch Tensor. + # The size of pixel_values might change in the cases where we resize + # the input but never exceeds the dimensions below. 
+ mm_data={ + "pixel_values": torch.full((1, 6, 512, 512), 1.0), + "location_coords": torch.full((1, 2), 1.0) + }) + + +class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + location_coords=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + return [] + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + mm_kwargs = {} + + for k, v in mm_data.items(): + mm_kwargs[k] = v + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=[1], + mm_kwargs=MultiModalKwargs(mm_kwargs), + mm_hashes=None, + mm_placeholders={}, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + PrithviGeoSpatialMAEMultiModalProcessor, + info=PrithviGeoSpatialMAEProcessingInfo, + dummy_inputs=PrithviGeoSpatialMAEInputBuilder) +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, + SupportsV0Only): + """ Prithvi Masked Autoencoder""" + + def _instantiate_model(self, config: dict) -> Optional[nn.Module]: + + # We might be able/need to support different tasks with this same model + if config["task_args"]["task"] == "SemanticSegmentationTask": + from terratorch.cli_tools import SemanticSegmentationTask + task = SemanticSegmentationTask( + config["model_args"], + config["task_args"]["model_factory"], + loss=config["task_args"]["loss"], + lr=config["task_args"]["lr"], + ignore_index=config["task_args"]["ignore_index"], + optimizer=config["task_args"]["optimizer"], + optimizer_hparams=config["optimizer_params"], + scheduler=config["task_args"]["scheduler"], + scheduler_hparams=config["scheduler_params"], + plot_on_val=config["task_args"]["plot_on_val"], + freeze_decoder=config["task_args"]["freeze_decoder"], + freeze_backbone=config["task_args"]["freeze_backbone"]) + + return task.model + else: + return None + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + # the actual model is dynamically instantiated using terratorch + # allowing us to perform changes to the model architecture + # at startup time (e.g., change the model decoder class.) + self.model = self._instantiate_model( + vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]) + if self.model is None: + raise ValueError( + "Unsupported task. " + "Only SemanticSegmentationTask is supported for now " + "by PrithviGeospatialMAE.") + + def _parse_and_validate_multimodal_data( + self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + + pixel_values = kwargs.pop("pixel_values", None) + if not isinstance(pixel_values, torch.Tensor): + raise ValueError(f"Incorrect type of pixel_values. " + f"Got type: {type(pixel_values)}") + pixel_values = torch.unbind(pixel_values, dim=0)[0] + + location_coords = kwargs.pop("location_coords", None) + if not isinstance(location_coords, torch.Tensor): + raise ValueError(f"Incorrect type of location_coords. 
" + f"Got type: {type(location_coords)}") + location_coords = torch.unbind(location_coords, dim=0)[0] + if location_coords.shape == torch.Size([0]): + location_coords = None + + return pixel_values, location_coords + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + + pixel_values, location_coords = ( + self._parse_and_validate_multimodal_data(**kwargs)) + model_output = self.model(pixel_values, + location_coords=location_coords) + + return model_output.output + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_list = [] + model_buffers = dict(self.named_buffers()) + loaded_buffers = [] + for key, value in weights: + if key == "state_dict": + weights_to_parse = value + for name, weight in weights_to_parse.items(): + if "pos_embed" in name: + continue + + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + + # this model requires a couple of buffers to be loaded + # that are not loadable with the AutoWeightsLoader + if name in model_buffers: + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + buffer = model_buffers[name] + weight_loader = getattr(buffer, "weight_loader", + default_weight_loader) + weight_loader(buffer, weight) + loaded_buffers.append(name) + else: + params_list.append((name, weight)) + break + + # Load the remaining model parameters + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(params_list) + + return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py new file mode 100644 index 0000000..a33739a --- /dev/null +++ b/vllm/model_executor/models/qwen.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# Copyright (c) Alibaba Cloud. 
+# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE +"""Inference-only QWen model compatible with HuggingFace weights.""" +import json +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class QWenMLP(nn.Module): + """MLP for the language component of the Qwen model, which contains a + MergedColumnParallelLinear merging 2 outputs via silu activation.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "silu", + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.c_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class QWenAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.c_attn = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + self.scaling = self.head_dim**-0.5 + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.c_proj(attn_output) + return output + + +class QWenBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + self.attn = QWenAttention(config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = QWenMLP(config.hidden_size, + config.intermediate_size // 2, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class QWenModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config 
+ quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: QWenBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class QWenBaseModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[QWenModel] = QWenModel, + ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + self.quant_config = quant_config + self.transformer = transformer_type(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w2", 0), + ("gate_up_proj", "w1", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias 
for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "c_attn": ["c_attn"], + "gate_up_proj": [ + "w2", + "w1", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + if hasattr(config, "visual"): + hf_overrides = { + "architectures": ["QwenVLForConditionalGeneration"] + } + raise RuntimeError( + "The configuration of this model indicates that it supports " + "vision inputs, but you instantiated the text-only version " + "of this model. Please use the vision model by setting " + f"`--hf-overrides '{json.dumps(hf_overrides)}'`") + + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py new file mode 100644 index 0000000..2831a5a --- /dev/null +++ b/vllm/model_executor/models/qwen2.py @@ -0,0 +1,562 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Qwen2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. 
Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Qwen2Model(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = {} is less than " + "`num_hidden_layers` = {}. 
Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to Qwen2DecoderLayer + decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config + 
+ self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + # TODO: Replace this model class with as_embedding_model( + # Qwen2ForCausalLM) after changing the default pooling method + if pooler_config.pooling_type is None: + logger.warning( + "This embedding model will default to last-token pooling in " + "an upcoming version. To avoid breaking changes, you should " + "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`" + " explicitly.") + + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.MEAN, + normalize=True, + softmax=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + return self.model(input_ids, positions, intermediate_tensors) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py new file mode 100644 index 0000000..1e6ff1f --- /dev/null +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -0,0 +1,1122 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" +from functools import cached_property, partial +from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set, + Tuple, TypedDict, Union) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BatchFeature +from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) + +from vllm.config import VllmConfig +from vllm.distributed import parallel_state, tensor_model_parallel_all_gather +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder +from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, + apply_rotary_pos_emb_vision) +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +# === Vision Inputs === # + + +class Qwen2_5_VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +class Qwen2_5_VLImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + image_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all images' features. + Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the images. + - `hidden_size` must match the hidden size of language model backbone. + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. 
+ """ + + +Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLImageEmbeddingInputs] + + +class Qwen2_5_VLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + second_per_grid_ts: torch.Tensor + """ + The video time interval (in seconds) for each grid along the temporal + dimension in the 3D position IDs. Returned when `videos` is not `None`. + """ + + +class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): + type: Literal["video_embeds"] + video_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all videos' features. + Each tensor holds an video's features. + - `torch.Tensor`: A tensor holding all videos' features + (concatenation of all videos' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the videos. + - `hidden_size` must match the hidden size of language model backbone. + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, + Qwen2_5_VLVideoEmbeddingInputs] + +# === Vision Encoder === # + + +class Qwen2_5_VisionMLP(nn.Module): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.gate_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj") + self.up_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj") + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + x_gate, _ = self.gate_proj(x) + x_gate = self.act_fn(x_gate) + x_up, _ = self.up_proj(x) + x_down, _ = self.down_proj(x_gate * x_up) + return x_down + + +class Qwen2_5_VisionAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + # Detect attention implementation. 
+ self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + }: + raise RuntimeError( + f"Qwen2.5-VL does not support {self.attn_backend} backend now." + ) + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN + q = apply_rotary_pos_emb_vision(q, + rotary_pos_emb, + use_flash_attn=use_flash_attn) + k = apply_rotary_pos_emb_vision(k, + rotary_pos_emb, + use_flash_attn=use_flash_attn) + + if self.attn_backend == _Backend.FLASH_ATTN: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. 
+ outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Qwen2_5_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen2_5_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen2_5_VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.ModuleList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + 
prefix=f"{prefix}.mlp.0"), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**(torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) + / self.dim)) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2_5_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2_5_VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + in_channels = vision_config.in_channels + depth = vision_config.depth + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + + # args for get_window_index + self.window_size = vision_config.window_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.fullatt_block_indexes = vision_config.fullatt_block_indexes + self.spatial_merge_unit = self.spatial_merge_size**2 + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) + + norm_layer = partial(RMSNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = 
torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = (self.window_size // + self.spatial_merge_size // self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = 
rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + hidden_states = hidden_states.unsqueeze(1) + + # pre-compute seqlens for window/full attn to reduce cuMemcpy operations + max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( + cu_seqlens) + max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( + cu_window_seqlens) + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen_full + seqlens_now = seqlens_full + else: + cu_seqlens_now = cu_window_seqlens + max_seqlen_now = max_seqlen_window + seqlens_now = seqlens_window + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen_now, + seqlens=seqlens_now, + ) + + # For Qwen2.5-VL-3B, float16 will overflow at last block + # for long visual tokens sequences. + if hidden_states.dtype == torch.float16: + hidden_states = cast_overflow_tensors(hidden_states) + + # adapter + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen2_5_VLConfig) + + def get_hf_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + fps: Optional[Union[float, List[float]]] = None, + **kwargs: object, + ) -> Qwen2_5_VLProcessor: + if fps is not None: + kwargs["fps"] = fps + + return self.ctx.get_hf_processor( + Qwen2_5_VLProcessor, + image_processor=self.get_image_processor(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size), + **kwargs, + ) + + +class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + 
dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Qwen2_5_VLImageInputs] = None, + video_input: Optional[Qwen2_5_VLVideoInputs] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Qwen2.5-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2.5-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. 
+ second_per_grid_ts: Tensor `(num_videos)` of video time interval ( + in seconds) for each grid along the temporal dimension in the + 3D position IDs. `None` if no videos are passed. + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.", + tower_model="visual.merger.") diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py new file mode 100644 index 0000000..ccb5a3f --- /dev/null +++ b/vllm/model_executor/models/qwen2_audio.py @@ -0,0 +1,422 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Any, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature +from transformers.models.qwen2_audio import (Qwen2AudioConfig, + Qwen2AudioEncoder, + Qwen2AudioProcessor) +from transformers.models.whisper import WhisperFeatureExtractor + +from vllm.config import VllmConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + + +# # === Audio Inputs === # +class Qwen2AudioInputs(TypedDict): + input_features: torch.Tensor + """Shape: `(num_audios, num_mel_bins, 3000)`""" + + feature_attention_mask: torch.Tensor + """Shape: `(num_audios, 3000)`""" + + +# === Audio Encoder === # + + +class Qwen2AudioMultiModalProjector(nn.Module): + + def __init__(self, audio_hidden_size: int, text_hidden_size: int): + super().__init__() + self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True) + + def forward(self, audio_features): + hidden_states = self.linear(audio_features) + return hidden_states + + +# From Qwen2AudioEncoder._get_feat_extract_output_lengths +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): + feat_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (feat_lengths - 2) // 2 + 1 + return feat_lengths, output_lengths + + +class Qwen2AudioProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen2AudioConfig) + + def get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + **kwargs: object, + ) -> Qwen2AudioProcessor: + return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs) + + def get_feature_extractor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> WhisperFeatureExtractor: + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + hf_config = self.get_hf_config() + max_source_positions = hf_config.audio_config.max_source_positions + max_output_lengths = (max_source_positions - 2) // 2 + 1 + + return {"audio": max_output_lengths} + + +class Qwen2AudioDummyInputsBuilder( + BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> 
ProcessorInputs: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text="<|AUDIO|>" * num_audios, + mm_data=mm_data, + ) + + +class Qwen2AudioMultiModalProcessor( + BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, Any], + ) -> BatchFeature: + # Text-only input not supported in composite processor + if not mm_data.get("audios", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + # Use getattr with default to be compatible with transformers<4.48 + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + audio_bos_token = getattr(processor, "audio_bos_token", + "<|audio_bos|>") + audio_eos_token = getattr(processor, "audio_eos_token", + "<|audio_eos|>") + + audio_token_id = vocab[audio_token] + audio_bos_id = vocab[audio_bos_token] + audio_eos_id = vocab[audio_eos_token] + + feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + if feature_attention_mask is None: + audio_output_lengths = [] + else: + assert isinstance(feature_attention_mask, torch.Tensor) + _, audio_output_lens = _get_feat_extract_output_lengths( + feature_attention_mask.sum(-1)) + + audio_output_lengths = audio_output_lens.tolist() + + def get_replacement_qwen2_audio(item_idx: int): + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio_len = audios.get_audio_length(item_idx) + + raise ValueError(f"The audio (len={audio_len}) is too short " + "to be represented inside the model") + + audio_tokens = [audio_token_id] * num_features + + return PromptUpdateDetails( + full=[audio_bos_id] + audio_tokens + [audio_eos_id], + features=audio_tokens, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2AudioMultiModalProcessor, + info=Qwen2AudioProcessingInfo, + dummy_inputs=Qwen2AudioDummyInputsBuilder) +class 
Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + self.audio_tower = Qwen2AudioEncoder(config.audio_config) + self.multi_modal_projector = Qwen2AudioMultiModalProjector( + config.audio_config.d_model, config.text_config.hidden_size) + + self.quant_config = quant_config + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + input_features = kwargs.pop('input_features', None) + feature_attention_mask = kwargs.pop('feature_attention_mask', None) + if input_features is None: + return None + input_features = self._validate_and_reshape_mm_tensor( + input_features, 'input_features') + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. 
" + f"Got type: {type(input_features)}") + return Qwen2AudioInputs(input_features=input_features, + feature_attention_mask=feature_attention_mask) + + def _process_audio_input(self, + audio_input: Qwen2AudioInputs) -> torch.Tensor: + + input_features = audio_input["input_features"] + feature_attention_mask = audio_input["feature_attention_mask"] + + audio_feat_lengths, audio_output_lengths = ( + self.audio_tower._get_feat_extract_output_lengths( + feature_attention_mask.sum(-1))) + + batch_size, _, max_mel_seq_len = input_features.shape + max_seq_len = (max_mel_seq_len - 2) // 2 + 1 + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = (torch.arange( + 0, + max_seq_len, + dtype=audio_feat_lengths.dtype, + device=audio_feat_lengths.device).unsqueeze(0).expand( + batch_size, max_seq_len)) + lengths_expand = audio_feat_lengths.unsqueeze(-1).expand( + batch_size, max_seq_len) + # Create mask + padding_mask = seq_range >= lengths_expand + + audio_attention_mask_ = padding_mask.view( + batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len, + max_seq_len) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.audio_tower.conv1.weight.dtype, + device=self.audio_tower.conv1.weight.device) + audio_attention_mask[audio_attention_mask_] = float("-inf") + + audio_outputs = self.audio_tower(input_features, + attention_mask=audio_attention_mask) + selected_audio_feature = audio_outputs.last_hidden_state + audio_features = self.multi_modal_projector(selected_audio_feature) + num_audios, max_audio_tokens, embed_dim = audio_features.shape + audio_output_lengths = audio_output_lengths.unsqueeze(1) + audio_features_mask = torch.arange(max_audio_tokens).expand( + num_audios, max_audio_tokens).to( + audio_output_lengths.device) < audio_output_lengths + masked_audio_features = audio_features[audio_features_mask].view( + -1, embed_dim) + + # Split to tuple of embeddings for individual audio input. + return torch.split(masked_audio_features, + audio_output_lengths.flatten().tolist()) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + masked_audio_features = self._process_audio_input(audio_input) + return masked_audio_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.audio_token_index) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py new file mode 100644 index 0000000..21855ba --- /dev/null +++ b/vllm/model_executor/models/qwen2_moe.py @@ -0,0 +1,533 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Qwen2MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2MoeSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None) + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + else: + self.shared_expert = None + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, + 1, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_output + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen2MoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen2MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # Note: Qwen/Qwen2-57B-A14B-Instruct does not have + # `mlp_only_layers` in the config. 
+ layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen2MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Qwen2MoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen2MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen2MoeForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + 
self.quant_config = quant_config + self.model = Qwen2MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py new file mode 100644 index 0000000..90f799e --- /dev/null +++ b/vllm/model_executor/models/qwen2_rm.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only +from .qwen2 import Qwen2Model +from .utils import AutoWeightsLoader, maybe_prefix + + +class ReLU(nn.Module): + + def __init__(self): + super().__init__() + self.activation = nn.ReLU() + + def forward(self, input): + input, _ = input + return self.activation(input) + + +class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP, + SupportsV0Only): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.score = nn.Sequential( + ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config), + ReLU(), + RowParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config), + ) + self._pooler: SimplePooler + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> 
torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + logits, _ = self.score(hidden_states) + return logits + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self, + ignore_unexpected_prefixes=["lm_head."]) + return loader.load_weights(weights) + + +class Qwen2ForRewardModel(Qwen2RewardBaseModel): + + def __init__(self, *, vllm_config, prefix=""): + vllm_config.model_config.hf_config.num_labels = 1 + super().__init__(vllm_config=vllm_config, prefix=prefix) + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.ALL, + normalize=False, + softmax=False) + + +class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): + + def __init__(self, *, vllm_config, prefix=""): + vllm_config.model_config.hf_config.num_labels = 2 + super().__init__(vllm_config=vllm_config, prefix=prefix) + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.STEP, + normalize=False, + softmax=True, + step_tag_id=151651, + ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py new file mode 100644 index 0000000..4276981 --- /dev/null +++ b/vllm/model_executor/models/qwen2_vl.py @@ -0,0 +1,1428 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property, partial +from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, + Union) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature +from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, + Qwen2VLProcessor) +from transformers.models.qwen2_vl.configuration_qwen2_vl import ( + Qwen2VLConfig, Qwen2VLVisionConfig) +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize + +from vllm.config import VllmConfig +from vllm.distributed import parallel_state, tensor_model_parallel_all_gather +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalFieldConfig, MultiModalKwargs, + VideoItem) +from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope +from vllm.transformers_utils.processor import ( + cached_image_processor_from_config) + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 + +# === Vision Inputs === # + + +class Qwen2VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +class Qwen2VLImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + image_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all images' features. + Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the images. + - `hidden_size` must match the hidden size of language model backbone. 
+    """
+
+    image_grid_thw: torch.Tensor
+    """Shape: `(num_images, 3)`
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
+                           Qwen2VLImageEmbeddingInputs]
+
+
+class Qwen2VLVideoPixelInputs(TypedDict):
+    type: Literal["pixel_values_videos"]
+    pixel_values_videos: torch.Tensor
+    """Shape:
+    `(num_patches,
+     num_channels * temporal_patch_size * patch_size * patch_size)`
+    """
+
+    video_grid_thw: torch.Tensor
+    """Shape: `(num_videos, 3)`
+
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+class Qwen2VLVideoEmbeddingInputs(TypedDict):
+    type: Literal["video_embeds"]
+    video_embeds: torch.Tensor
+    """Supported types:
+    - List[`torch.Tensor`]: A list of tensors holding all videos' features.
+        Each tensor holds a video's features.
+    - `torch.Tensor`: A tensor holding all videos' features
+        (concatenation of all videos' feature tensors).
+
+    Tensor shape: `(num_video_features, hidden_size)`
+    - `num_video_features` varies based on
+        the number and resolution of the videos.
+    - `hidden_size` must match the hidden size of language model backbone.
+    """
+
+    video_grid_thw: torch.Tensor
+    """Shape: `(num_videos, 3)`
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
+                           Qwen2VLVideoEmbeddingInputs]
+
+# === Vision Encoder === #
+
+
+class Qwen2VisionMLP(nn.Module):
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        act_layer: type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(in_features,
+                                        hidden_features,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(hidden_features,
+                                     in_features,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(torch.stack((-x2, x1), dim=-1),
+                         "... d two -> ... (d two)",
+                         two=2)
+
+
+def apply_rotary_emb_torch(x: torch.Tensor,
+                           cos: torch.Tensor,
+                           sin: torch.Tensor,
+                           interleaved: bool = False) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos,
+        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+    sin = repeat(
+        sin,
+        "... d -> ... 
1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor, + use_flash_attn=False) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + apply_rotary_emb = apply_rotary_emb_torch + if use_flash_attn: + from flash_attn.layers.rotary import apply_rotary_emb + output = apply_rotary_emb(t_, cos, sin).type_as(t) + return output + + +class Qwen2VisionAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = world_size + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size) + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + }: + raise RuntimeError( + f"Qwen2-VL does not support {self.attn_backend} backend now.") + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + + # [s, b, c] --> [s, b, 3 * head * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.attn_backend == _Backend.FLASH_ATTN: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Qwen2VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: type[nn.Module] = QuickGELU, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Qwen2VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen2VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.embed_dim = embed_dim + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.embed_dim) + return x + + +class Qwen2VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.ModuleList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**(torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) + / self.dim)) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + self.patch_embed = Qwen2VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen2VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + self.merger = Qwen2VisionPatchMerger( + d_model=hidden_size, + context_dim=embed_dim, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + 
pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + x = x.unsqueeze(1) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + for blk in self.blocks: + x = blk( + x, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + # adapter + x = self.merger(x) + + return x + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", 
video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + +class Qwen2VLMultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={"image_embeds", "image_grid_thw"}, + fields_factory=_qwen2vl_field_config, + ) + + return super()._parse_image_data(data) + + def _parse_video_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="video", + required_fields={"video_embeds", "video_grid_thw"}, + fields_factory=_qwen2vl_field_config, + ) + + return super()._parse_video_data(data) + + +class Qwen2VLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen2VLConfig) + + def get_hf_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, + ) -> Qwen2VLProcessor: + return self.ctx.get_hf_processor( + Qwen2VLProcessor, + image_processor=self.get_image_processor(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size), + **kwargs, + ) + + def _get_image_processor_kwargs( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, + ): + if self.ctx.model_config.mm_processor_kwargs: + kwargs.update(self.ctx.model_config.mm_processor_kwargs) + + if min_pixels is not None: + kwargs["min_pixels"] = min_pixels + + if size is None: + size = {"shortest_edge": min_pixels} + else: + size["shortest_edge"] = min_pixels + + if max_pixels is not None: + kwargs["max_pixels"] = max_pixels + + if size is None: + size = {"longest_edge": max_pixels} + else: + size["longest_edge"] = max_pixels + + if size is not None: + kwargs["size"] = size + + return kwargs + + def get_image_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, + ) -> Qwen2VLImageProcessor: + return cached_image_processor_from_config( + self.ctx.model_config, + **self._get_image_processor_kwargs(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size, + **kwargs), + ) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "video": self.get_max_video_tokens(seq_len, mm_counts), + } + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: Optional[Qwen2VLImageProcessor], + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + 
min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + num_frames % temporal_patch_size + + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (merge_size**2) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: Optional[Qwen2VLImageProcessor], + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[Qwen2VLImageProcessor], + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + image_processor=image_processor, + ) + return num_video_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + image_processor=None, + ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 1) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + hf_processor = 
self.info.get_hf_processor() + image_token: str = hf_processor.image_token + video_token: str = hf_processor.video_token + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + + return ProcessorInputs( + prompt_text=image_token * num_images + video_token * num_videos, + mm_data=mm_data, + ) + + +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] + ): + + def _get_data_parser(self) -> MultiModalDataParser: + return Qwen2VLMultiModalDataParser() + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + self.info._get_image_processor_kwargs(**mm_kwargs), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + placeholder = { + "image": vocab[hf_processor.image_token], + "video": vocab[hf_processor.video_token], + } + + merge_length = image_processor.merge_size**2 + + def get_replacement_qwen2vl(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_qwen2vl, + modality=modality), + ) for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _qwen2vl_field_config(hf_inputs) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder) +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. 
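+    # Checkpoint weights saved under "lm_head.*" / "model.*" are remapped onto
+    # the wrapped decoder ("language_model.lm_head.*" / "language_model.model.*").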
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: Qwen2VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + return Qwen2VLImageEmbeddingInputs(type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2VLVideoEmbeddingInputs(type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_image_input( + self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Qwen2VLImagePixelInputs] = None, + video_input: Optional[Qwen2VLVideoPixelInputs] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Qwen2-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. 
+ """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.", + tower_model="visual.merger.") diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py new file mode 100644 index 0000000..9c14038 --- /dev/null +++ b/vllm/model_executor/models/qwen3.py @@ -0,0 +1,329 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen3 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Qwen3Config + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .qwen2 import Qwen2MLP as Qwen3MLP +from .qwen2 import Qwen2Model +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix + +logger = init_logger(__name__) + + +class Qwen3Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type) + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, Qwen3 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. 
Alibaba-NLP/gte-Qwen3-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = Qwen3Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = Qwen3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Qwen3DecoderLayer, +} + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
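+        # -1 marks the last dimension (seq_len) as dynamic in either case.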
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Qwen3Model(Qwen2Model): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=Qwen3DecoderLayer) + + +class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py new file mode 100644 index 0000000..390bb7a --- /dev/null +++ b/vllm/model_executor/models/qwen3_moe.py @@ -0,0 +1,531 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Qwen3MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen3MoeSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + final_hidden_states = final_hidden_states + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3MoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen3MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. 
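+        # Layers listed in `mlp_only_layers`, and layers that do not fall on a
+        # `decoder_sparse_step` boundary, use a dense MLP instead of the MoE block.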
+ layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Qwen3MoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen3MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3MoeForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = 
vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. 
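+                # With pipeline parallelism, each rank only owns a contiguous
+                # slice of the decoder layers.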
+ if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py new file mode 100644 index 0000000..4e9d02a --- /dev/null +++ b/vllm/model_executor/models/qwen_vl.py @@ -0,0 +1,787 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py +# Copyright (c) Alibaba Cloud. 
+"""Inference-only Qwen-VL model compatible with HuggingFace weights.""" + +import copy +import math +import re +import unicodedata +from collections.abc import Collection, Mapping, Sequence +from collections.abc import Set as AbstractSet +from functools import lru_cache, partial +from typing import Callable, List, Literal, Optional, TypedDict, Union + +import torch +from torch import nn +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer, + TensorType) +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput + +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .qwen import QWenBaseModel, QWenModel +from .utils import flatten_bn, merge_multimodal_embeddings + + +class QwenImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, 3, image_size, image_size)` + + Note that image_size is the value in the vision config to which we resize + the image to in the normalization transform. Currently multi-image support + can only be leveraged by passing image embeddings directly. + """ + + +class QwenImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, 256, hidden_size)` + + `hidden_size` must match the hidden size of the language model backbone + and is stored in the visual config of the model if we have one. + """ + + +QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] + + +class VisualAttention(nn.Module): + """self-attention layer class. + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim \ + and self.vdim == embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + assert embed_dim % num_heads == 0 + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. 
+ assert self._qkv_same_embed_dim, \ + 'Visual Attention implementation only supports self-attention' + self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim) + self.out_proj = ReplicatedLinear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # query/key/value: [sq, b, h] + sq, b, _ = x.size() + mixed_x_layer, _ = self.in_proj(x) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split( + self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, + key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, + self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output, _ = self.out_proj(context_layer) + + return output + + +class QwenVLMLP(nn.Module): + """MLP for the visual component of the Qwen model.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.c_fc = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config) + self.act_fn = get_act_fn("gelu") + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + x, _ = self.c_fc(x) + x = self.act_fn(x) + x, _ = self.c_proj(x) + return x + + +class VisualAttentionBlock(nn.Module): + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = QwenVLMLP( + hidden_size=d_model, + intermediate_size=mlp_width, + quant_config=quant_config, + ) + + def attention( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_mask = 
attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, attn_mask=attn_mask) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class TransformerBlock(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList([ + VisualAttentionBlock(width, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, + **kwargs): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, + image_width // patch_width) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + # class embeddings and positional embeddings + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * + torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock(width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + + self.attn_pool = Resampler2( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + adaptive=False, + do_post_projection=False, + ).to( + device=self.positional_embedding.device, + dtype=self.positional_embedding.dtype, + ) + + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter( + (output_dim**-0.5) * torch.randn(output_dim, output_dim)) + + self.image_start_id = image_start_id + self.image_end_id = image_start_id + 1 + self.image_pad_id = image_start_id + 2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( + x.size(1)))) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x) + x = x @ self.proj + + return x + + +class QwenVLModel(QWenModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: 
str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.visual = VisionTransformer(**config.visual, + quant_config=quant_config) + + +@lru_cache(maxsize=1) +def _get_tokenizer_without_image_pad( + tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: + """ + The logic of adding image pad tokens should only be applied in + :class:`QwenVLProcessor`, so they are patched out here. + + The definition of the wrapped tokenizer can be found here: + https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py + """ + new_tokenizer = copy.deepcopy(tokenizer) + + class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore + + def tokenize( + self, + text: str, + allowed_special: Union[AbstractSet[str], str] = "all", + disallowed_special: Union[Collection[str], str] = (), + **kwargs, + ) -> list[Union[bytes, str]]: + text = unicodedata.normalize("NFC", text) + + return [ + self.decoder[t] for t in self.tokenizer.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ] + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + errors: Optional[str] = None, + **kwargs, + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + + return self.tokenizer.decode( + token_ids, + errors=errors or self.errors, + ) + + TokenizerWithoutImagePad.__name__ = \ + f"{tokenizer.__class__.__name__}WithoutImagePad" + + new_tokenizer.__class__ = TokenizerWithoutImagePad + return new_tokenizer + + +class QwenVLProcessor: + """ + This model doesn't define its own HF processor, + so we implement our own one here. + + We call the wrapped tokenizer to automatically insert image pad tokens: + https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245 + + The image processor is defined here: + https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354 + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + vision_config = config.visual + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) + + @property + def image_start_tag(self) -> str: + return self.tokenizer.image_start_tag # type: ignore + + @property + def image_end_tag(self) -> str: + return self.tokenizer.image_end_tag # type: ignore + + @property + def image_pad_tag(self) -> str: + return self.tokenizer.image_pad_tag # type: ignore + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + text_inputs = self.tokenizer(text) + + if len(images) == 0: + image_inputs = {} + else: + pixel_values = [self.image_transform(image) for image in images] + image_inputs = {"pixel_values": torch.stack(pixel_values)} + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + 
tensor_type=return_tensors, + ) + + +class QwenVLProcessingInfo(BaseProcessingInfo): + + def get_tokenizer(self) -> PreTrainedTokenizer: + tokenizer = self.ctx.tokenizer + assert isinstance(tokenizer, PreTrainedTokenizer) + + return _get_tokenizer_without_image_pad(tokenizer) + + def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor: + return self.ctx.init_processor( + QwenVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.visual + + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length + + +class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + vision_config = hf_config.visual + + processor = self.info.get_hf_processor() + img_start = processor.image_start_tag + img_end = processor.image_end_tag + + target_width = target_height = vision_config["image_size"] + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n" + for i in range(1, num_images + 1)), + mm_data=mm_data, + ) + + +class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Drops anything between / tags; encoding with the tokenizer + # will automatically add the image pads for the context. 
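+        # Example (illustrative, with an assumed filename): a prompt such as
+        #   "Picture 1: <img>example.jpg</img>\n"
+        # is reduced to "Picture 1: <img></img>\n", so only the start/end tags
+        # remain for the tokenizer to expand into image pad tokens.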
+        prompt, num_matched_images = re.subn(
+            r"(Picture \d*: <img>).*?(<\/img>\n)",
+            r"\1\2",
+            prompt,
+        )
+
+        image_data = mm_data.get("images")
+        if image_data is not None:
+            assert isinstance(image_data, list)
+
+            num_images = len(image_data)
+            assert num_matched_images == num_images
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _hf_processor_applies_updates(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> bool:
+        return False
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        tokenizer = self.info.get_tokenizer()
+        special_tokens: dict[str,
+                             int] = tokenizer.special_tokens  # type: ignore
+
+        processor = self.info.get_hf_processor()
+        img_start_id = special_tokens[processor.image_start_tag]
+        img_end_id = special_tokens[processor.image_end_tag]
+        img_pad_id = special_tokens[processor.image_pad_tag]
+
+        num_image_tokens = self.info.get_num_image_tokens()
+        image_tokens = [img_pad_id] * num_image_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[img_start_id, img_end_id],
+                replacement=PromptUpdateDetails(
+                    full=[img_start_id] + image_tokens + [img_end_id],
+                    features=image_tokens,
+                ),
+            )
+        ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(QwenVLMultiModalProcessor,
+                                        info=QwenVLProcessingInfo,
+                                        dummy_inputs=QwenVLDummyInputsBuilder)
+class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
+                                     SupportsMultiModal):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        transformer_type: type[QwenVLModel] = QwenVLModel,
+    ) -> None:
+        super().__init__(
+            vllm_config=vllm_config,
+            prefix=prefix,
+            transformer_type=transformer_type,
+        )
+
+        self.transformer: QwenVLModel
+
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.visual["image_size"]
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
+            raise ValueError(
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[QwenImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. 
" + f"Got type: {type(pixel_values)}") + + return QwenImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return QwenImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + return None + + def _process_image_input(self, + image_input: QwenImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + return self.transformer.visual(image_input["data"]) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.transformer.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.transformer.visual.image_pad_id) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py new file mode 100644 index 0000000..261feef --- /dev/null +++ b/vllm/model_executor/models/registry.py @@ -0,0 +1,601 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Whenever you add an architecture to this page, please also update +`tests/models/registry.py` with example HuggingFace models for it. 
+""" +import importlib +import os +import pickle +import subprocess +import sys +import tempfile +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from functools import lru_cache +from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type, + TypeVar, Union) + +import cloudpickle +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_in_doc_build + +from .interfaces import (has_inner_state, has_noops, is_attention_free, + is_hybrid, supports_cross_encoding, + supports_multimodal, supports_pp, + supports_transcription, supports_v0_only) +from .interfaces_base import is_text_generation_model + +logger = init_logger(__name__) + +# yapf: disable +_TEXT_GENERATION_MODELS = { + # [Decoder-only] + "AquilaModel": ("llama", "LlamaForCausalLM"), + "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), + "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), + # baichuan-7b, upper case 'C' in the class name + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), + # baichuan-13b, lower case 'c' in the class name + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BambaForCausalLM": ("bamba", "BambaForCausalLM"), + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "CohereForCausalLM": ("commandr", "CohereForCausalLM"), + "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"), + "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), + "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), + "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), + "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), + "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), + "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), + "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), + "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), + "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 + "GritLM": ("gritlm", "GritLM"), + "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), + "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), + "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), + "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM"), + "FalconMambaForCausalLM": ("mamba", 
"MambaForCausalLM"), + "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), + "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), + "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), + "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "OrionForCausalLM": ("orion", "OrionForCausalLM"), + "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), + "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), + "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), + "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "SolarForCausalLM": ("solar", "SolarForCausalLM"), + "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), + "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"), + "XverseForCausalLM": ("llama", "LlamaForCausalLM"), + "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"), + # [Encoder-decoder] + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), +} + +_EMBEDDING_MODELS = { + # [Text-only] + "BertModel": ("bert", "BertEmbeddingModel"), + "RobertaModel": ("roberta", "RobertaEmbeddingModel"), + "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), + "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), + "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), + "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), + "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "GritLM": ("gritlm", "GritLM"), + "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), + "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 + "LlamaModel": ("llama", "LlamaForCausalLM"), + **{ + # Multiple models share the same architecture, so we include them all + k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() + if arch == "LlamaForCausalLM" + }, + "MistralModel": ("llama", "LlamaForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), + "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), + "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), + # [Multimodal] + "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + 
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 + # [Auto-converted (see adapters.py)] + "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"), + # Technically PrithviGeoSpatialMAE is a model that works on images, both in + # input and output. I am adding it here because it piggy-backs on embedding + # models for the time being. + "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), +} + +_CROSS_ENCODER_MODELS = { + "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "RobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), + "XLMRobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), +} + +_MULTIMODAL_MODELS = { + # [Decoder-only] + "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), + "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501 + "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), + "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 + "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 + "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), + "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), + "InternVLChatModel": ("internvl", "InternVLChatModel"), + "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), + "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), + "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 + "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 + "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 + "MiniCPMO": ("minicpmo", "MiniCPMO"), + "MiniCPMV": ("minicpmv", "MiniCPMV"), + "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 + "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), + "NVLM_D": ("nvlm_d", "NVLM_D_Model"), + "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 + "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 + "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 + "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 + "UltravoxModel": ("ultravox", "UltravoxModel"), + "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), + # [Encoder-decoder] + "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 + "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 + "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 + "SkyworkR1VChatModel": 
("skyworkr1v", "SkyworkR1VChatModel"), + "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 +} + +_SPECULATIVE_DECODING_MODELS = { + "EAGLEModel": ("eagle", "EAGLE"), + "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), + "MedusaModel": ("medusa", "Medusa"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), +} + +_TRANSFORMERS_MODELS = { + "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), +} +# yapf: enable + +_VLLM_MODELS = { + **_TEXT_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_CROSS_ENCODER_MODELS, + **_MULTIMODAL_MODELS, + **_SPECULATIVE_DECODING_MODELS, + **_TRANSFORMERS_MODELS, +} + +# This variable is used as the args for subprocess.run(). We +# can modify this variable to alter the args if needed. e.g. +# when we use par format to pack things together, sys.executable +# might not be the target we want to run. +_SUBPROCESS_COMMAND = [ + sys.executable, "-m", "vllm.model_executor.models.registry" +] + + +@dataclass(frozen=True) +class _ModelInfo: + architecture: str + is_text_generation_model: bool + is_pooling_model: bool + supports_cross_encoding: bool + supports_multimodal: bool + supports_pp: bool + has_inner_state: bool + is_attention_free: bool + is_hybrid: bool + has_noops: bool + supports_transcription: bool + supports_v0_only: bool + + @staticmethod + def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": + return _ModelInfo( + architecture=model.__name__, + is_text_generation_model=is_text_generation_model(model), + is_pooling_model=True, # Can convert any model into a pooling model + supports_cross_encoding=supports_cross_encoding(model), + supports_multimodal=supports_multimodal(model), + supports_pp=supports_pp(model), + has_inner_state=has_inner_state(model), + is_attention_free=is_attention_free(model), + is_hybrid=is_hybrid(model), + supports_transcription=supports_transcription(model), + supports_v0_only=supports_v0_only(model), + has_noops=has_noops(model), + ) + + +class _BaseRegisteredModel(ABC): + + @abstractmethod + def inspect_model_cls(self) -> _ModelInfo: + raise NotImplementedError + + @abstractmethod + def load_model_cls(self) -> Type[nn.Module]: + raise NotImplementedError + + +@dataclass(frozen=True) +class _RegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has already been imported in the main process. + """ + + interfaces: _ModelInfo + model_cls: Type[nn.Module] + + @staticmethod + def from_model_cls(model_cls: Type[nn.Module]): + return _RegisteredModel( + interfaces=_ModelInfo.from_model_cls(model_cls), + model_cls=model_cls, + ) + + def inspect_model_cls(self) -> _ModelInfo: + return self.interfaces + + def load_model_cls(self) -> Type[nn.Module]: + return self.model_cls + + +@dataclass(frozen=True) +class _LazyRegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has not been imported in the main process. 
+ """ + module_name: str + class_name: str + + # Performed in another process to avoid initializing CUDA + def inspect_model_cls(self) -> _ModelInfo: + return _run_in_subprocess( + lambda: _ModelInfo.from_model_cls(self.load_model_cls())) + + def load_model_cls(self) -> Type[nn.Module]: + mod = importlib.import_module(self.module_name) + return getattr(mod, self.class_name) + + +@lru_cache(maxsize=128) +def _try_load_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> Optional[Type[nn.Module]]: + from vllm.platforms import current_platform + current_platform.verify_model_arch(model_arch) + try: + return model.load_model_cls() + except Exception: + logger.exception("Error in loading model architecture '%s'", + model_arch) + return None + + +@lru_cache(maxsize=128) +def _try_inspect_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> Optional[_ModelInfo]: + try: + return model.inspect_model_cls() + except Exception: + logger.exception("Error in inspecting model architecture '%s'", + model_arch) + return None + + +@dataclass +class _ModelRegistry: + # Keyed by model_arch + models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict) + + def get_supported_archs(self) -> AbstractSet[str]: + return self.models.keys() + + def register_model( + self, + model_arch: str, + model_cls: Union[Type[nn.Module], str], + ) -> None: + """ + Register an external model to be used in vLLM. + + :code:`model_cls` can be either: + + - A :class:`torch.nn.Module` class directly referencing the model. + - A string in the format :code:`:` which can be used to + lazily import the model. This is useful to avoid initializing CUDA + when importing the model and thus the related error + :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. + """ + if not isinstance(model_arch, str): + msg = f"`model_arch` should be a string, not a {type(model_arch)}" + raise TypeError(msg) + + if model_arch in self.models: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls) + + if isinstance(model_cls, str): + split_str = model_cls.split(":") + if len(split_str) != 2: + msg = "Expected a string in the format `:`" + raise ValueError(msg) + + model = _LazyRegisteredModel(*split_str) + elif isinstance(model_cls, type) and (is_in_doc_build() or issubclass( + model_cls, nn.Module)): + model = _RegisteredModel.from_model_cls(model_cls) + else: + msg = ("`model_cls` should be a string or PyTorch model class, " + f"not a {type(model_arch)}") + raise TypeError(msg) + + self.models[model_arch] = model + + def _raise_for_unsupported(self, architectures: List[str]): + all_supported_archs = self.get_supported_archs() + + if any(arch in all_supported_archs for arch in architectures): + raise ValueError( + f"Model architectures {architectures} failed " + "to be inspected. Please check the logs for more details.") + + raise ValueError( + f"Model architectures {architectures} are not supported for now. 
" + f"Supported architectures: {all_supported_archs}") + + def _try_load_model_cls(self, + model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in self.models: + return None + + return _try_load_model_cls(model_arch, self.models[model_arch]) + + def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: + if model_arch not in self.models: + return None + + return _try_inspect_model_cls(model_arch, self.models[model_arch]) + + def _normalize_archs( + self, + architectures: Union[str, List[str]], + ) -> List[str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + # filter out support architectures + normalized_arch = list( + filter(lambda model: model in self.models, architectures)) + + # make sure Transformers backend is put at the last as a fallback + if len(normalized_arch) != len(architectures): + normalized_arch.append("TransformersForCausalLM") + return normalized_arch + + def inspect_model_cls( + self, + architectures: Union[str, List[str]], + ) -> Tuple[_ModelInfo, str]: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_info = self._try_inspect_model_cls(arch) + if model_info is not None: + return (model_info, arch) + + return self._raise_for_unsupported(architectures) + + def resolve_model_cls( + self, + architectures: Union[str, List[str]], + ) -> Tuple[Type[nn.Module], str]: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + return self._raise_for_unsupported(architectures) + + def is_text_generation_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_text_generation_model + + def is_pooling_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_pooling_model + + def is_cross_encoder_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_cross_encoding + + def is_multimodal_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal + + def is_pp_supported_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_pp + + def model_has_inner_state( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.has_inner_state + + def is_attention_free_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_attention_free + + def is_hybrid_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_hybrid + + def is_noops_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.has_noops + + def is_transcription_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_transcription + + def 
is_v1_compatible( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return not model_cls.supports_v0_only + + +ModelRegistry = _ModelRegistry({ + model_arch: + _LazyRegisteredModel( + module_name=f"vllm.model_executor.models.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() +}) + +_T = TypeVar("_T") + + +def _run_in_subprocess(fn: Callable[[], _T]) -> _T: + # NOTE: We use a temporary directory instead of a temporary file to avoid + # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file + with tempfile.TemporaryDirectory() as tempdir: + output_filepath = os.path.join(tempdir, "registry_output.tmp") + + # `cloudpickle` allows pickling lambda functions directly + input_bytes = cloudpickle.dumps((fn, output_filepath)) + + # cannot use `sys.executable __file__` here because the script + # contains relative imports + returned = subprocess.run(_SUBPROCESS_COMMAND, + input=input_bytes, + capture_output=True) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError(f"Error raised in subprocess:\n" + f"{returned.stderr.decode()}") from e + + with open(output_filepath, "rb") as f: + return pickle.load(f) + + +def _run() -> None: + # Setup plugins + from vllm.plugins import load_general_plugins + load_general_plugins() + + fn, output_file = pickle.loads(sys.stdin.buffer.read()) + + result = fn() + + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) + + +if __name__ == "__main__": + _run() diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py new file mode 100644 index 0000000..a09741a --- /dev/null +++ b/vllm/model_executor/models/roberta.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import RobertaConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import CrossEncodingPooler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel +from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) + +from .interfaces import SupportsCrossEncoding, SupportsV0Only + + +def roberta_task_weights_filter( + all_weights: Iterable[Tuple[str, torch.Tensor]] +) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str, + torch.Tensor]]]: + """ + Separate task-specific weights that are applied on top + of the encoder-decoder bert base. + To do so, return two generators over the original iterator. + Also, remove the "roberta." prefix to make it loadable + from vanilla BertModel. + """ + # Copy of a lazy iterator without in-memory overhead so both + # iterators can be iterated upon independently. 
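+    # Illustrative example (assumed checkpoint naming): a weight named
+    # "roberta.encoder.layer.0.attention.self.query.weight" is yielded by the
+    # first generator as "encoder.layer.0.attention.self.query.weight", while
+    # a task weight such as "classifier.dense.weight" is left to the second
+    # generator untouched.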
+ all_weights1, all_weights2 = itertools.tee(all_weights) + + def encoder_decoder_weights(): + for name, weight in all_weights1: + if name.startswith("roberta."): + yield (name[len("roberta."):], weight) + + return encoder_decoder_weights(), ((n, w) for n, w in all_weights2 + if not n.startswith("roberta.")) + + +class RobertaEmbedding(nn.Module): + + def __init__(self, config: RobertaConfig): + super().__init__() + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size, + padding_idx=self.padding_idx) + + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.position_ids = nn.Parameter( + torch.empty((1, config.max_position_embeddings)), ) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") + + def forward( + self, + input_ids: torch.Tensor, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_shape = input_ids.size() + inputs_embeds = self.word_embeddings(input_ids) + + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + pos_list = [] + token_list = [] + offset = 0 + for seq_len in seq_lens: + pos_list.append(position_ids[offset:offset + seq_len]) + token_list.append(input_ids[offset:offset + seq_len]) + offset += seq_len + + new_pos_list = [] + for positions, tokens in zip(pos_list, token_list): + # Verify assumption that incoming position are + # always a sequence from 0 to N. + expected_pos = torch.arange(positions.size()[0], + dtype=torch.long, + device=inputs_embeds.device) + assert torch.equal(positions, expected_pos) + new_pos_list.append( + create_position_ids_from_input_ids(tokens, self.padding_idx)) + position_ids = torch.cat(new_pos_list) + + # Position embeddings. + position_embeddings = self.position_embeddings(position_ids) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +# Adapted from transformers +def create_position_ids_from_input_ids(input_ids, + padding_idx, + past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. 
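+    # Worked example (illustrative): with padding_idx = 1 and
+    # input_ids = [5, 7, 1, 1], the mask is [1, 1, 0, 0], the masked
+    # cumulative sum is [1, 2, 0, 0], and the returned position ids are
+    # [2, 3, 1, 1]; real tokens count up from padding_idx + 1 while padding
+    # positions stay at padding_idx.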
+ mask = input_ids.ne(padding_idx).int() + + incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + + past_key_values_length) * mask + + return incremental_indices.long() + padding_idx + + +# Adapted from transformers +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: RobertaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[0, :] # take token (equiv. to [CLS]) + x = self.dense(x) + x = torch.tanh(x) + x = self.out_proj(x) + return x + + +class RobertaEmbeddingModel(BertEmbeddingModel): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + return BertModel(vllm_config=vllm_config, + prefix=prefix, + embedding_class=RobertaEmbedding) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.hf_to_vllm_mapper.apply(weights) + # Separate weights in "roberta"-prefixed and all else (not in memory). + # For use with models like FacebookAI/roberta-base. + bert_weights, task_weights = roberta_task_weights_filter(weights) + loaded = self.model.load_weights(bert_weights) + if not len(loaded): + # Fix for models like `sentence-transformers/stsb-roberta-base-v2` + # which use the same architecture, but have no "roberta" prefix. + loaded = self.model.load_weights(task_weights) + assert len(loaded), "Unable to load RobertaEmbeddingModel" + + +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, + SupportsV0Only): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + roberta: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ + + jina_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + 'emb_ln': "embeddings.LayerNorm", + 'layers': "layer", + 'mixer.Wqkv': "attention.self.qkv_proj", + 'mixer.out_proj': "attention.output.dense", + 'norm1': "attention.output.LayerNorm", + 'mlp.fc1': "intermediate.dense", + 'mlp.fc2': "output.dense", + 'norm2': "output.LayerNorm", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.roberta = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=RobertaEmbedding, + add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + self._pooler = CrossEncodingPooler(config, self.classifier) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + bert_weights, task_weights = roberta_task_weights_filter(weights) + bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) + + self.roberta.load_weights(bert_weights) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in task_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.roberta(input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py new file mode 100644 index 0000000..cecad9e --- /dev/null +++ b/vllm/model_executor/models/siglip.py @@ -0,0 +1,525 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +import math +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import SiglipVisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs + + +class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return self.get_patch_grid_length()**2 + + def get_max_image_tokens(self) -> int: + return self.get_patch_grid_length()**2 + + def get_image_size(self) -> int: + return self.vision_config.image_size + + def 
get_patch_size(self) -> int: + return self.vision_config.patch_size + + def get_patch_grid_length(self) -> int: + image_size, patch_size = self.get_image_size(), self.get_patch_size() + return image_size // patch_size + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + self.position_embedding = VocabParallelEmbedding( + self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, dtype=torch.int64).expand( + (1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """ + This method is an adapted method for SigLIP (due to SigLIP not having + class embedding unlike other ViTs) that allows the model to interpolate + the pre-trained position encodings such that it can be usable on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] + num_positions = position_embeddings.shape[1] + if num_patches == num_positions and height == width: + return position_embeddings + + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error + # in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + + patch_pos_embed = position_embeddings.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), + dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if (int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1]): + raise ValueError("Width or height does not match with " + "the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +class SiglipAttention(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got " + "`embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.out_proj(out) + + return attn_output, None + + +class SiglipMLP(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + # Special handling for BNB and torchao quantization + if quant_config and quant_config.get_name() in [ + "bitsandbytes", "torchao" + ]: + quantizable = True + else: + # For other quantization, we require the hidden size to be a + # multiple of 64 + quantizable = (config.hidden_size % 64 == 0 + and config.intermediate_size % 64 == 0) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config if quantizable else None, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config if quantizable else None, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.embed_dim = config.hidden_size + + self.self_attn = SiglipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = 
hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None + + +class SiglipEncoder(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + return_all_hidden_states: bool, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + hidden_states_pool = [inputs_embeds] + hidden_states = inputs_embeds + + for encoder_layer in self.layers: + hidden_states, _ = encoder_layer(hidden_states) + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool + return hidden_states + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SiglipVisionTransformer(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + + self.encoder = SiglipEncoder( + config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + + # If possible, skip post_layernorm to conserve memory + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + self.post_layernorm = None + + self.use_head = (True if not hasattr(config, "vision_use_head") else + config.vision_use_head) + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.head", + ) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = True, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + + hidden_states = self.embeddings( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + return_all_hidden_states = feature_sample_layers is not None + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states, + ) + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, feature_sample_layers, self.post_layernorm, + self.config.num_hidden_layers) + + # TODO: add this back when pooled_output is used in inference. + # if self.use_head: + # pooled_output = self.head(encoder_outputs) + + return encoder_outputs + + +class SiglipVisionModel(nn.Module): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.vision_model = SiglipVisionTransformer( + config, + quant_config, + num_hidden_layers_override=num_hidden_layers_override, + require_post_norm=require_post_norm, + prefix=f"{prefix}.vision_model", + ) + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + feature_sample_layers=feature_sample_layers, + ) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is optional in SiglipVisionModel + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = 
param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
new file mode 100644
index 0000000..ac5de0e
--- /dev/null
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -0,0 +1,1014 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
+# --------------------------------------------------------
+# SkyworkR1V
+# Copyright (c) 2025 Skywork
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
+from functools import cached_property
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import BatchEncoding, PretrainedConfig, TensorType
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.models.intern_vit import (InternVisionModel,
+                                                   InternVisionPatchModel)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
+
+IMG_START = '<img>'
+IMG_END = '</img>'
+IMG_CONTEXT = '<IMG_CONTEXT>'
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class SkyworkR1VImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    pixel_values_flat: torch.Tensor
+    """
+    Shape:
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+    """
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+ + Shape: `(batch_size * num_images, num_embeds)` + """ + + +class SkyworkR1VImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, + SkyworkR1VImageEmbeddingInputs] + + +# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ +def build_transform(input_size: int): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + return T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + + +# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ +def find_closest_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + *, + width: int, + height: int, + image_size: int, +) -> tuple[int, int]: + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def resolve_skyworkr1v_min_max_num( + *, + min_dynamic_patch: int, + max_dynamic_patch: int, + dynamic_image_size: bool, + use_thumbnail: bool, +) -> tuple[int, int]: + min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 + + if use_thumbnail and max_dynamic_patch != 1: + max_dynamic_patch += 1 + + return min_dynamic_patch, max_dynamic_patch + + +def get_skyworkr1v_target_ratios( + min_num: int, + max_num: int, +) -> list[tuple[int, int]]: + target_ratios = {(i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) if min_num <= i * j <= max_num} + return sorted(target_ratios, key=lambda x: x[0] * x[1]) + + +def calculate_skyworkr1v_targets( + *, + orig_width: int, + orig_height: int, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[int, int, int]: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # add thumbnail image if num_blocks != 1 + if use_thumbnail and blocks != 1: + blocks += 1 + + return blocks, target_width, target_height + + +def dynamic_preprocess_skyworkr1v( + image: Image.Image, + *, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> list[Image.Image]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + blocks, target_width, target_height = calculate_skyworkr1v_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + 
image_size=image_size, + use_thumbnail=False, + ) + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ((i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images + + +# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B +def image_to_pixel_values_skyworkr1v( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, +) -> torch.Tensor: + target_ratios = get_skyworkr1v_target_ratios(min_num, max_num) + + transform = build_transform(input_size=input_size) + images = dynamic_preprocess_skyworkr1v( + image, + target_ratios=target_ratios, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) + + pixel_values = torch.stack([transform(image) for image in images]) + return pixel_values + + +class BaseSkyworkR1VProcessor(ABC): + """ + This model doesn't define its own HF processor, + so we implement our own one here. + + The code to insert image tokens is based on: + https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252 + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + image_size: int = config.vision_config.image_size + patch_size: int = config.vision_config.patch_size + + if min_dynamic_patch is None: + min_dynamic_patch = config.min_dynamic_patch + assert isinstance(min_dynamic_patch, int) + + if max_dynamic_patch is None: + max_dynamic_patch = config.max_dynamic_patch + assert isinstance(max_dynamic_patch, int) + + if dynamic_image_size is None: + dynamic_image_size = config.dynamic_image_size + assert isinstance(dynamic_image_size, bool) + + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.image_size = image_size + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail: bool = config.use_thumbnail + + @property + @abstractmethod + def image_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + raise NotImplementedError + + def resolve_min_max_num( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> tuple[int, int]: + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) + max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch + is None else max_dynamic_patch) + dynamic_image_size = (self.dynamic_image_size if dynamic_image_size + is None else dynamic_image_size) + use_thumbnail = (self.use_thumbnail + if use_thumbnail is None 
else use_thumbnail) + + return resolve_skyworkr1v_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + def resolve_target_ratios( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> list[tuple[int, int]]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + return get_skyworkr1v_target_ratios(min_num, max_num) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + target_ratios = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + ) + + num_patches, _, _ = calculate_skyworkr1v_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios, + use_thumbnail=self.use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=False, # Applied in image_to_pixel_values + ) + + return [ + image_to_pixel_values_skyworkr1v( + image, + input_size=self.image_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=self.use_thumbnail, + ) for image in images + ] + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> Mapping[str, NestedTensors]: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + if len(images) == 0: + image_inputs = {} + else: + pixel_values_lst = self._images_to_pixel_values_lst( + images, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + image_inputs: dict[str, NestedTensors] = { + "pixel_values_flat": + torch.cat(pixel_values_lst), + "image_num_patches": + torch.tensor([len(item) for item in pixel_values_lst]), + } + + tokenizer = self.tokenizer + image_token_id = self.image_token_id + + embed_is_patch = list[torch.Tensor]() + + for pixel_values in pixel_values_lst: + num_patches = pixel_values.shape[0] + feature_size = num_patches * self.num_image_token + + image_repl = self.get_image_repl(feature_size, num_patches) + feature_tokens = tokenizer.encode(image_repl.features, + add_special_tokens=False) + + text = [t.replace('', image_repl.full, 1) for t in text] + embed_is_patch.append( + torch.tensor(feature_tokens) == image_token_id) + + image_inputs["embed_is_patch"] = embed_is_patch + + text_inputs = self.tokenizer(text) + + return { + **BatchEncoding(text_inputs, tensor_type=return_tensors), + **image_inputs, + } + + +class 
SkyworkR1VProcessor(BaseSkyworkR1VProcessor): + + @property + def image_token_id(self) -> int: + return self.tokenizer.get_vocab()[IMG_CONTEXT] + + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + repl_features = IMG_CONTEXT * feature_size + repl_full = IMG_START + repl_features + IMG_END + + return PromptUpdateDetails(full=repl_full, features=repl_features) + + +class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): + + @abstractmethod + def get_hf_processor( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + **kwargs: object, + ) -> BaseSkyworkR1VProcessor: + raise NotImplementedError + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[BaseSkyworkR1VProcessor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + ) + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + base_size = processor.image_size + target_ratios = processor.resolve_target_ratios() + + largest_feature_size, largest_feature_pinpoint = 0, None + for wr, hr in target_ratios: + width, height = base_size * wr, base_size * hr + + feat_size = self.get_num_image_tokens( + image_width=width, + image_height=height, + processor=processor, + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + +_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo) + + +class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="" * num_images, + mm_data=mm_data, + ) + + +class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_token_id = hf_processor.image_token_id + + # Since there may be extra tokens in the feature placeholders, + # we need to pass the image token ID to the model to select the + # tokens to merge from the vision encoder outputs + 
processed_outputs["image_token_id"] = torch.tensor(image_token_id) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) + num_images = len(image_num_patches) + + return dict( + pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_patches), + image_num_patches=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + image_token_id=MultiModalFieldConfig.shared("image", num_images), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + if "image_num_patches" in out_mm_kwargs: + image_num_patches = out_mm_kwargs["image_num_patches"] + assert isinstance(image_num_patches, torch.Tensor) + image_num_patches = image_num_patches.tolist() + elif "image_embeds" in out_mm_kwargs: + # TODO: Use image size information in dictionary embedding inputs + # to compute num_patches (similar to Qwen2-VL) + image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + else: + image_num_patches = [] + + def get_replacement_skyworkr1v(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + feature_size = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + feature_size = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + num_patches = image_num_patches[item_idx] + if num_patches is not None: + assert isinstance(num_patches, int) + + return hf_processor.get_image_repl(feature_size, num_patches) + + return [ + PromptReplacement( + modality="image", + target="", + replacement=get_replacement_skyworkr1v, + ) + ] + + +class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo): + + def get_hf_processor( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + **kwargs: object, + ) -> SkyworkR1VProcessor: + if min_dynamic_patch is not None: + kwargs["min_dynamic_patch"] = min_dynamic_patch + if max_dynamic_patch is not None: + kwargs["max_dynamic_patch"] = max_dynamic_patch + if dynamic_image_size is not None: + kwargs["dynamic_image_size"] = dynamic_image_size + + return self.ctx.init_processor( + SkyworkR1VProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + SkyworkR1VMultiModalProcessor, + info=SkyworkR1VProcessingInfo, + dummy_inputs=SkyworkR1VDummyInputsBuilder) +class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self._patch_quant_config(config, quant_config) + + image_size = config.force_image_size or config.vision_config.image_size + patch_size 
= config.vision_config.patch_size + self.patch_size = patch_size + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + + self.llm_arch_name = config.text_config.architectures[0] + self.is_mono = self.llm_arch_name == 'SkyworkLM2VEForCausalLM' + self.vision_model = self._init_vision_model( + config, + quant_config=quant_config, + is_mono=self.is_mono, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.mlp1 = self._init_mlp1(config) + + self.img_context_token_id = None + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _patch_quant_config(self, config: PretrainedConfig, + quant_config: QuantizationConfig): + # the awq models from OpenGVLab missing `modules_to_not_convert` + # patch the quant_config to add `modules_to_not_convert` back + if isinstance(quant_config, AWQConfig): + text_config = config.text_config + llm_quant_config = getattr(text_config, "quantization_config", + None) + if (not quant_config.modules_to_not_convert) and \ + (llm_quant_config is not None): + quant_config.modules_to_not_convert.append("vision_model") + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + return InternVisionPatchModel(config.vision_config) + + def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.text_config.hidden_size + + return nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), + ReplicatedLinear(vit_hidden_size * + int(1 / self.downsample_ratio)**2, + llm_hidden_size, + return_bias=False), + nn.GELU(), + ReplicatedLinear(llm_hidden_size, + llm_hidden_size, + return_bias=False), + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + pass + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: + vit_embeds = self.vision_model(pixel_values=pixel_values) + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = 
vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: + pixel_values_flat = kwargs.pop("pixel_values_flat", None) + image_num_patches = kwargs.pop("image_num_patches", None) + embed_is_patch = kwargs.pop("embed_is_patch", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return SkyworkR1VImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + image_token_id = kwargs["image_token_id"] + assert isinstance(image_token_id, torch.Tensor) + self.img_context_token_id = image_token_id.flatten().unique().item() + + if pixel_values_flat is not None: + if not isinstance(pixel_values_flat, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}") + + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. " + f"Got type: {type(image_num_patches)}") + + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) + embed_is_patch = flatten_bn(embed_is_patch) + + return SkyworkR1VImagePixelInputs( + type="pixel_values", + pixel_values_flat=self._validate_pixel_values( + pixel_values_flat), + num_patches=image_num_patches, + embed_is_patch=embed_is_patch, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: SkyworkR1VImageInputs, + ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) + + num_patches = image_input["num_patches"] + + # Only one image in the current batch + if len(num_patches) == 1: + return image_embeds.view( + -1, self.config.text_config.hidden_size).unsqueeze(0) + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. 
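+        # Illustrative example (values are hypothetical): with
+        # num_patches = [2, 3] and feature_size = 256 image tokens per
+        # patch, the flattened embeddings are split into chunks of 512
+        # and 768 rows, i.e. one tensor per input image.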
+ feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in num_patches + ] + return image_embeds.split(image_feature_sizes) + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if self.is_mono: + self.visual_token_mask = ( + input_ids == self.img_context_token_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + if image_input["type"] != "pixel_values": + return image_features + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert self.img_context_token_id is not None + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.img_context_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[SamplerOutput, IntermediateTensors]: + + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + # Only required if the model is mono-architecture + if self.visual_token_mask is not None: + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model(**forward_kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + skip_prefixes = [ + "action_embed", "temporal_embed", "track_embed", + "track_embed_decoder", "box_token", "cg_criterion", "cg_model", + "loc_encoder", "loc_decoder", "sam", "temporal_token", + "track_token" + ] + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py new file mode 100644 index 0000000..1cae0a7 --- /dev/null +++ b/vllm/model_executor/models/solar.py @@ -0,0 +1,515 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Solar model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class SolarMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SolarAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class SolarDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] \ + = config.original_max_position_embeddings + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = SolarAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = SolarMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is 
None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class SolarModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SolarDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + bskcn_h_1 = None + bskcn_h_2 = None + bskcn_r_1 = None + bskcn_r_2 = None + bskcn_tv = (self.config.bskcn_tv[0] + if self.training else self.config.bskcn_tv[1]) + + for i in range(self.start_layer, self.end_layer): + if i in self.config.bskcn_1: + bskcn_h_1 = hidden_states.clone() + bskcn_r_1 = residual.clone() + if i in self.config.bskcn_2: + bskcn_h_2 = hidden_states.clone() + bskcn_r_2 = residual.clone() + if i in self.config.bskcn_3: + hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv) + if i in self.config.bskcn_4: + hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv) + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + 
return hidden_states + + +class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = SolarModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py new file mode 100644 index 0000000..a15faec --- /dev/null +++ b/vllm/model_executor/models/stablelm.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This code is based off the following work: +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json +"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) +model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import StableLmConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class StablelmMLP(nn.Module): + + def __init__(self, + config: StableLmConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, [config.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class StablelmAttention(nn.Module): + + def __init__(self, + config: StableLmConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tp_size + + self.total_num_key_value_heads = config.num_key_value_heads + if self.total_num_key_value_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_key_value_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_key_value_heads == 0 + self.num_key_value_heads = max( + 1, self.total_num_key_value_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + rope_pct = getattr(config, "rope_pct", + getattr(config, "partial_rotary_factor", 1)) + self.rotary_ndims = int(self.head_dim * rope_pct) + self.scaling = self.head_dim**-0.5 + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_key_value_heads * self.head_dim + self.qkv_bias = getattr(config, "use_qkv_bias", False) + if (self.head_dim * self.num_heads * tp_size) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_key_value_heads, + self.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_ndims, + max_position=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class StablelmDecoderLayer(nn.Module): + + def __init__( + self, + config: StableLmConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self_attn = StablelmAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp") + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +class StableLMEpochModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + 
prefix=f"{prefix}.embed_tokens", + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: StablelmDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class StablelmForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = StableLMEpochModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head") + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, 
loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py new file mode 100644 index 0000000..3d11dfd --- /dev/null +++ b/vllm/model_executor/models/starcoder2.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch Starcoder2 model.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Starcoder2Config + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Starcoder2Attention(nn.Module): + + def __init__(self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
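+            # For example, with num_key_value_heads=4 and tp_size=8 the
+            # assert below holds (8 % 4 == 0), and each rank keeps exactly
+            # one KV head, shared by tp_size // num_key_value_heads
+            # tensor-parallel ranks.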
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + self.max_position_embeddings = config.max_position_embeddings + self.use_bias = config.use_bias + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=self.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Starcoder2MLP(nn.Module): + + def __init__(self, + config: Starcoder2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.c_fc = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.c_fc", + ) + self.c_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + self.act = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class Starcoder2DecoderLayer(nn.Module): + + def __init__(self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Starcoder2Attention(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = Starcoder2MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + 
+ return hidden_states + + +@support_torch_compile +class Starcoder2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Starcoder2DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Starcoder2ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.model = Starcoder2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + 
sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py new file mode 100644 index 0000000..2627c20 --- /dev/null +++ b/vllm/model_executor/models/telechat2.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, Set, Tuple + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel + +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter) + + +class TeleChat2Model(LlamaModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # 1. Initialize the LlamaModel with bias + vllm_config.model_config.hf_config.bias = True + vllm_config.model_config.hf_config.mlp_bias = True + super().__init__(vllm_config=vllm_config, prefix=prefix) + # 2. 
Remove the bias from the qkv_proj and gate_up_proj based on config
+        # Telechat2's gate_up_proj and qkv_proj don't have bias
+        # see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566
+        for layer in self.layers:
+            if not isinstance(layer, PPMissingLayer):
+                layer.self_attn.qkv_proj.bias = None
+                layer.self_attn.qkv_proj.skip_bias_add = True
+                layer.mlp.gate_up_proj.bias = None
+                layer.mlp.gate_up_proj.skip_bias_add = True
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            ('gate_up_proj', 'gate_proj', 0),
+            ('gate_up_proj', 'up_proj', 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        total_num_heads = self.config.n_head
+        head_dim = self.config.hidden_size // total_num_heads
+        for name, loaded_weight in weights:
+            if "self_attn.key_value" in name:
+                k_weight = []
+                v_weight = []
+                for i in range(total_num_heads):
+                    start = i * head_dim * 2
+                    k_weight.append(loaded_weight[start:start + head_dim, :])
+                    v_weight.append(loaded_weight[start + head_dim:start +
+                                                  2 * head_dim:])
+                k_weight = torch.cat(k_weight, dim=0)
+                v_weight = torch.cat(v_weight, dim=0)
+                name = name.replace("key_value", "qkv_proj")
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, k_weight, "k")
+                weight_loader(param, v_weight, "v")
+            elif "query" in name:
+                name = name.replace("query", "qkv_proj")
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, "q")
+            else:
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class TeleChat2ForCausalLM(LlamaForCausalLM):
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "transformer.": "model.",
+        },
+        orig_to_new_substr={
+            ".h.": ".layers.",
+            ".self_attention.": ".self_attn.",
+            ".word_embeddings.": ".embed_tokens.",
+            ".dense.": ".o_proj.",
+            ".ln_f.": ".norm.",
+        },
+    )
+
+    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+        return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py
new file mode 100644
index 0000000..e670b1d
--- /dev/null
+++ b/vllm/model_executor/models/teleflm.py
@@ -0,0 +1,79 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# 
https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.models.llama import (LlamaDecoderLayer, + LlamaForCausalLM, LlamaModel) + + +class TeleFLMModel(LlamaModel): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer, + ): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + """ + This implementation is based on the µScaling paper presented at + the ICLR 2025 Workshop: + NanoLM: An Affordable LLM Study Benchmark \ + via Accurate Loss Prediction across Scales + by Yiqun Yao et al. + Available at: https://openreview.net/forum?id=IwaPYg1SCA + arXiv preprint: https://arxiv.org/abs/2304.06875 + """ + self.use_mup = self.config.use_mup + if self.use_mup: + self.input_mult = self.config.input_mult + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + if self.use_mup: + embedding = embedding * self.input_mult + return embedding + + +class TeleFLMForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + # mup + self.use_mup = self.config.use_mup + if self.use_mup: + self.mup_scale_factor = self.config.mup_scale_factor + self.output_mult = self.config.output_mult / self.mup_scale_factor + logit_scale = self.output_mult + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size, + logit_scale) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py new file mode 100644 index 0000000..a1f233e --- /dev/null +++ b/vllm/model_executor/models/transformers.py @@ -0,0 +1,451 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Wrapper around `transformers` models""" +import re +from itertools import chain +from typing import Iterable, Literal, Optional, Union + +import torch +from torch import nn +from transformers import AutoModel, PretrainedConfig, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.utils import get_pp_indices +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, maybe_prefix) + +logger = init_logger(__name__) + + +def vllm_flash_attention_forward( + # Transformers args + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + # Transformers kwargs + scaling: Optional[float] = None, + # vLLM kwargs + attention_instances: Optional[dict[Attention]] = None, + **kwargs): + self_attn = attention_instances[module.layer_idx] + if scaling is not None: + self_attn.impl.scale = float(scaling) + hidden = query.shape[-2] + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) + return self_attn.forward(query, key, value), None + + +ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward + + +def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): + logger.debug("%s: %s -> %s", name, old_module, new_module) + + +def replace_linear_class( + linear: nn.Linear, style: Literal["colwise", "rowwise"], + quant_config: QuantizationConfig +) -> Union[ColumnParallelLinear, RowParallelLinear]: + """ + Replace nn.Linear with one of vLLM's tensor parallel linear classes. + + Args: + linear (nn.Linear): `nn.Linear` to be replaced. + style (str): Tensor parallel style of the new linear, e.g. "colwise". + quant_config (QuantConfig): Quantization config for the new linear. + Returns: + Union[ColumnParallelLinear, RowParallelLinear]: The new linear. 
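+
+    Example (illustrative; assumes the TP plan marks this layer "colwise"):
+        `replace_linear_class(nn.Linear(1024, 4096), "colwise", quant_config)`
+        returns a `ColumnParallelLinear` with the same input/output sizes,
+        sharding the output dimension across tensor-parallel ranks; unknown
+        styles fall back to `ReplicatedLinear`.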
+ """ + + if not isinstance(style, str): + raise ValueError( + f"Unsupported parallel style type {type(style)}, expected str") + + vllm_linear_cls = { + "colwise": ColumnParallelLinear, + "rowwise": RowParallelLinear, + }.get(style, ReplicatedLinear) + + return vllm_linear_cls( + input_size=linear.in_features, + output_size=linear.out_features, + bias=linear.bias is not None, + quant_config=quant_config, + return_bias=False, + ) + + +class TransformersModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + logger.info("Using Transformers backend.") + + config: PretrainedConfig = vllm_config.model_config.hf_config + cache_config: CacheConfig = vllm_config.cache_config + device_config: DeviceConfig = vllm_config.device_config + model_config: ModelConfig = vllm_config.model_config + parallel_config: ParallelConfig = vllm_config.parallel_config + quant_config: QuantizationConfig = vllm_config.quant_config + + self.config = config + self.cache_config = cache_config + self.device_config = device_config + self.model_config = model_config + self.parallel_config = parallel_config + self.quant_config = quant_config + + self.pp_group = get_pp_group() + self.pp_size = self.pp_group.world_size + self.pp_rank = self.pp_group.rank_in_group + self.tp_size = get_tensor_model_parallel_world_size() + + # Use meta device to delay allocating GPU tensors + with torch.device("meta"): + # FIXME(Isotr0py): We need to refactor this part in the future to + # avoid registering an extra model layer, otherwise we will need a + # weights mapper to rename weights. + self.model: PreTrainedModel = AutoModel.from_config( + config, + attn_implementation="vllm", + torch_dtype=model_config.dtype, + trust_remote_code=model_config.trust_remote_code, + ) + + self.pipeline_parallel() + self.tensor_parallel() + + # Input embeddings + if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): + self.model.set_input_embeddings( + VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + )) + + # Attention layers + self.attention_instances = self.create_attention_instances() + + # Initialize buffers (e.g. rotary embedding inverse frequency) + self.init_buffers(self.model) + + # Move remaining meta tensors to device (should happen last) + self.meta_to_empty(self.model) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def pipeline_parallel(self): + """ + Apply the model's pipeline parallelization plan. 
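+        Modules that belong to other pipeline ranks are replaced with
+        `PPMissingLayer` placeholders, so only this rank's shard of the
+        model keeps real parameters.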
+ """ + if self.pp_size <= 1: + return + + if not self.model.supports_pp_plan: + raise ValueError( + f"{type(self.model)} does not support pipeline parallel yet!") + + module_lists = [] + module_list_idx = None + pp_plan = list(self.model._pp_plan.keys()) + for i, name in enumerate(pp_plan): + if isinstance(getattr(self.model, name), nn.ModuleList): + module_lists.append(name) + module_list_idx = i + + if len(module_lists) > 1: + raise ValueError( + "Pipeline parallel of models with multiple `ModuleList`s " + "in the base model are not supported yet!") + if module_list_idx is None: + raise ValueError( + f"Could not find `ModuleList` in {type(self.model)}") + + # Layers before module list + for name in pp_plan[:module_list_idx]: + if self.pp_group.is_first_rank or (self.config.tie_word_embeddings + and self.pp_group.is_last_rank): + continue + setattr(self.model, name, PPMissingLayer()) + + # Module list + start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers, + self.pp_rank, self.pp_size) + layers_name = pp_plan[module_list_idx] + layers = getattr(self.model, layers_name) + for i in range(len(layers)): + if start_layer <= i and i < end_layer: + continue + layers[i] = PPMissingLayer(return_tuple=True) + + # Layers after module list + for name in pp_plan[module_list_idx + 1:]: + # Modules that should be on last rank + if not self.pp_group.is_last_rank: + setattr(self.model, name, PPMissingLayer()) + + def tensor_parallel(self): + """ + Apply the model's tensor parallelization plan. + Currently only supports linear layers. + """ + if not self.model.supports_tp_plan: + if self.tp_size <= 1: + return + + raise ValueError( + f"{type(self.model)} does not support tensor parallel yet!") + + tp_plan = self.model._tp_plan + + def _tensor_parallel(module: nn.Module, prefix: str = ""): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + for pattern, style in tp_plan.items(): + if re.match(pattern, qual_name) and isinstance( + child_module, nn.Linear): + new_module = replace_linear_class( + child_module, style, self.quant_config) + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) + else: + _tensor_parallel(child_module, prefix=qual_name) + + _tensor_parallel(self.model) + + def create_attention_instances(self) -> dict[int, Attention]: + """ + Create `Attention` instances to inform KV cache allocation. + """ + num_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + start, end = get_pp_indices(self.config.num_hidden_layers, + self.pp_rank, self.pp_size) + return { + i: + Attention( + num_heads=num_heads, + head_size=head_size, + # NOTE: We use Llama scale as default, if it's set by + # Transformers, it's updated in vllm_flash_attention_forward + scale=head_size**-0.5, + num_kv_heads=num_kv_heads, + cache_config=self.cache_config, + quant_config=self.quant_config, + prefix=f"{i}.attn") + for i in range(start, end) + } + + def init_buffers(self, module: nn.Module): + """ + If a `buffer` is on the `meta` device, then its parent + `module` is the original module created by: + + ```python + with torch.device("meta"): + self.model: PreTrainedModel = AutoModel.from_config(...) 
+ ``` + + This means that: + - `type(module)` is a class from `transformers` + - This class is constructed using a `PretrainedConfig` + """ + for name, buffer in module.named_buffers(recurse=False): + if buffer.device == torch.device("meta"): + new_buffer = getattr(type(module)(self.config), name) + setattr(module, name, new_buffer) + for child in module.children(): + self.init_buffers(child) + + def meta_to_empty(self, module: nn.Module): + tensors = list(chain(module.buffers(), module.parameters())) + if tensors and all(t.device == torch.device("meta") for t in tensors): + module.to_empty(device=self.device_config.device) + return # We can stop recursing because to_empty is recursive + for child in module.children(): + self.meta_to_empty(child) + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if not get_pp_group().is_first_rank: + assert intermediate_tensors is not None + input_ids = None + inputs_embeds = intermediate_tensors["hidden_states"] + + if input_ids is not None: + input_ids = input_ids[None, ...] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[None, ...] + + hidden_states = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + use_cache=False, + position_ids=positions[None, ...], + attention_instances=self.attention_instances, + return_dict=False)[0][0, ...] # we remove batch dimension for now + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params = set[str]() + for name, loaded_weight in weights: + # Use "model" instead of base_model_prefix because + # the base model attribute in vLLM is always `model` + if not name.startswith(prefix := "model."): + name = prefix + name + + if is_pp_missing_parameter(name, self): + continue + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile +class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, + SupportsPP): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens" + ] # TODO transformers will have a util to get it + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: PretrainedConfig = vllm_config.model_config.hf_config + quant_config: QuantizationConfig = vllm_config.quant_config + + self.config = config + + self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings()) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.sampler = 
get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend, + # this makes thing complicated. We need to remove this mapper after refactor + # `TransformersModel` in the future. + @property + def hf_to_vllm_mapper(self): + prefix_mapper = { + name: "model." + name + for name, _ in self.model.model.named_children() + } + return WeightsMapper( + orig_to_new_substr={"model.": "model.model."}, + orig_to_new_prefix=prefix_mapper, + ) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py new file mode 100644 index 0000000..6e73a2a --- /dev/null +++ b/vllm/model_executor/models/ultravox.py @@ -0,0 +1,680 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py +"""PyTorch Ultravox model.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from torch.nn import functional as F +from transformers import BatchFeature, ProcessorMixin +from transformers.models.whisper import WhisperFeatureExtractor +from transformers.models.whisper.modeling_whisper import WhisperEncoder + +from vllm import envs +from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.loader import DefaultModelLoader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.ultravox import UltravoxConfig + 
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings, + merge_multimodal_embeddings_from_map) + +_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" +_AUDIO_PLACEHOLDER_TOKEN = 128002 +_AUDIO_TOKENS_PER_SECOND = 6.25 +_MAX_ENCODER_BATCH_SIZE = 16 + + +class UltravoxAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]] + """Shape: `(batch_size, num_chunks, 80, M)`""" + lens: Union[torch.Tensor, list[torch.Tensor]] + """ + Length of the audio frames. Used for attention mask in WhisperEncoder. + Shape: `(batch_size, num_chunks)` + """ + token_len: Union[torch.Tensor, list[torch.Tensor]] + """ + Length of the audio tokens. Used for flattening the audio features. + Shape: `(batch_size, num_chunks)` + """ + + +class UltravoxAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`""" + + +UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, + UltravoxAudioEmbeddingInputs] + + +class UltravoxProcessingInfo(BaseProcessingInfo): + + def get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + **kwargs: object, + ) -> ProcessorMixin: + hf_processor = self.ctx.get_hf_processor(**kwargs) + + # NOTE: Ultravox processing definition uses '<|eot_id|>' as the + # placeholder that will cause confusion with the actual end of turn + # token, thus we override placeholder with a reserved special + # token. + hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE + hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN + return hf_processor + + def get_feature_extractor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> WhisperFeatureExtractor: + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) + audio_processor = hf_processor.audio_processor # type: ignore + feature_extractor = audio_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + feature_extractor = self.get_feature_extractor() + max_audio_tokens = math.ceil(feature_extractor.chunk_length * + _AUDIO_TOKENS_PER_SECOND) + + return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE} + + +class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] + ): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = (feature_extractor.chunk_length * sampling_rate * + _MAX_ENCODER_BATCH_SIZE) + num_audios = mm_counts.get("audio", 0) + + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text="<|audio|>" * num_audios, + mm_data=mm_data, + ) + + +class UltravoxMultiModalProcessor( + BaseMultiModalProcessor[UltravoxProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor 
= self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Text-only input not supported in composite processor + if not mm_data.get("audios", []): + prompt_ids = self.info.get_tokenizer().encode( + prompt, add_special_tokens=False) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + assert isinstance(audios, list) + + feature_extractor = self.info.get_feature_extractor() + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + include_audio_num_chunks=True, + ) + + item_processor_data = dict(**mm_data, audios=audios) + + output = super()._call_hf_processor( + prompt=prompt, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, + ) + output['audio_features'] = output.pop('audio_values') + + return output + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0)) + return dict( + # to handle longer than 30s audio, each audio might be split + # into multiple chunks as such, their batch dimension can be + # higher than the number of audio samples + audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + audio_token_len=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + audio_lens=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + # num_chunks can convert audio_chunked to audio batch dimension + audio_num_chunks=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + replacement_id = hf_processor.audio_replacement_token_id # type: ignore + + # Each audio can be split into multiple chunks. + # chunks_start_idx[i] indicates the start index of the chunks + # belonging to the i-th audio. + num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0)) + chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, + dim=0, + dtype=torch.int32) + chunks_start_idx = torch.cat( + [torch.tensor([0], dtype=torch.int32), chunks_start_idx]) + + def get_replacement_ultravox(item_idx: int): + start = chunks_start_idx[item_idx] + end = chunks_start_idx[item_idx + 1] + audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum() + return [replacement_id] * int(audio_token_len) # type: ignore + + return [ + PromptReplacement( + modality="audio", + target="<|audio|>", + replacement=get_replacement_ultravox, + ) + ] + + +class StackAudioFrames(nn.Module): + """ + Stack the audio embedding frames to reduce the sequence length by a factor + of `stack_factor`. 
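+
+    For example, with the default `stack_factor=8`, an input of shape
+    `(B, T, C)` is zero-padded so that `T` becomes a multiple of 8 and is
+    then reshaped to `(B, T / 8, C * 8)`.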
+ """ + + def __init__(self, stack_factor: int = 8): + super().__init__() + self.stack_factor = stack_factor + + def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: + B, T, C = audio_embeds.shape + T_pad = (T + self.stack_factor - + 1) // self.stack_factor * self.stack_factor + audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T)) + B, T, C = audio_embeds.shape + audio_embeds = audio_embeds.view(B, T // self.stack_factor, + C * self.stack_factor) + return audio_embeds + + +class UltravoxProjector(nn.Module): + + def __init__(self, config: UltravoxConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self._pad_and_stack = StackAudioFrames(config.stack_factor) + dim_in = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim_in) + self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False) + dim_mid = self.hidden_dim + + if config.projector_act == "swiglu": + self.act = MulAndSilu() + dim_mid = dim_mid // 2 + else: + self.act = get_act_fn(config.projector_act) + + dim_out = config.text_config.hidden_size + self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) + + # Ultravox v0.4.1 and below use layer_norm after the second linear layer + # while v0.5.0 and above uses layer_norm after the first linear layer. + if config.projector_ln_mid: + self.ln_mid: nn.Module = RMSNorm(dim_mid) + self.ln_post = nn.Identity() + else: + self.ln_mid = nn.Identity() + self.ln_post = RMSNorm(dim_out) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + audio_features = self._pad_and_stack(audio_features) + audio_features = self.ln_pre(audio_features) + hidden_states = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states = self.ln_mid(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.ln_post(hidden_states) + return hidden_states + + +class ModifiedWhisperEncoder(WhisperEncoder): + """ + Encoder portion of OpenAI's Whisper model. + + This implementation is a slightly modified version of HF Transformers' + Whisper Encoder, with only a few fixes: + 1. base_model_prefix updated to allow for doing `.from_pretrained` + directly on the encoder + 2. 
allow less than 30 second of audio padding to be passed in: + - relaxed ValueError check for `input_features` length to be less + than or equal to `expected_seq_length` instead of strictly equal + - embed_pos is now sliced to match the length of `inputs_embeds` + + Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py + See commentary: https://github.com/huggingface/transformers/issues/25744 + """ + + base_model_prefix = "model.encoder" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.config.is_decoder = False + + @property + def max_context_length(self): + return (self.config.max_source_positions * self.conv1.stride[0] * + self.conv2.stride[0]) + + def get_attention_mask_by_audio_len(self, + audio_lens: Optional[torch.Tensor], + hidden_states: torch.Tensor): + """ + Create attention mask based on audio lengths to mask out padding tokens + For each sample in batch: + - Convert raw audio length to feature length after convolutions + - Create bool mask: True for valid positions and False for padding + - Convert to attention mask format expected by transformer layers + (1.0 for positions to attend to, large negative for positions to ignore) + This masking ensures consistent behavior between training and inference + by preventing the model from attending to padding tokens in both cases + """ + if audio_lens is None: + return None + + audio_feature_len = self._get_feat_extract_output_lengths(audio_lens) + max_seq_len = hidden_states.shape[1] + attention_mask = torch.arange(max_seq_len, + device=hidden_states.device)[None, :].lt( + audio_feature_len.view(-1, 1)) + attention_mask = self.get_extended_attention_mask( + attention_mask, + None, + dtype=hidden_states.dtype, + ) + return attention_mask + + def forward( + self, + input_features: torch.Tensor, + audio_lens: Optional[torch.Tensor] = None, + ): + expected_seq_length = self.max_context_length + if input_features.shape[-1] > expected_seq_length: + raise ValueError( + f"Whisper expects the mel input features to be of length " + f"{expected_seq_length} or less, but found " + f"{input_features.shape[-1]}. 
Make sure to pad the input mel " + f"features to {expected_seq_length}.") + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + attention_mask = self.get_attention_mask_by_audio_len( + audio_lens, hidden_states) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=None, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + UltravoxMultiModalProcessor, + info=UltravoxProcessingInfo, + dummy_inputs=UltravoxDummyInputsBuilder) +class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multi_modal_config = multimodal_config + assert self.multi_modal_config + + self.secondary_weights = [] + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + if config.audio_model_id is not None: + # this prefix is not for initialization, but for loading weights + # note the trailing dot + self.secondary_weights.append( + DefaultModelLoader.Source( + model_or_path=config.audio_model_id, + revision=None, + prefix="audio_tower.", + )) + self.multi_modal_projector = UltravoxProjector(config) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + if config.text_model_id is not None: + # this prefix is not for initialization, but for loading weights + # note the trailing dot + self.secondary_weights.append( + DefaultModelLoader.Source(model_or_path=config.text_model_id, + revision=None, + prefix="language_model.")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model.", + connector="multi_modal_projector.", + tower_model="audio_tower.", + ) + + def _audio_features_to_embeddings( + self, input_features: torch.Tensor, + audio_lens: torch.Tensor) -> torch.Tensor: + audio_features = input_features.to(self.audio_tower.dtype) + batch_size = audio_features.size(0) + audio_embeddings = [] + + # Process audio features in batches to keep memory usage predictable + for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE): + end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size) + # Process through audio tower + batch_features = self.audio_tower(audio_features[start:end], + audio_lens[start:end]) + batch_features = 
batch_features.to(self.audio_tower.dtype) + + # Process through projector + batch_embeddings = self.multi_modal_projector(batch_features) + audio_embeddings.append(batch_embeddings) + + # Concatenate results + audio_embeddings = torch.cat(audio_embeddings, dim=0) + return audio_embeddings + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[UltravoxAudioInputs]: + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + audio_lens = kwargs.pop("audio_lens", None) + audio_token_len = kwargs.pop("audio_token_len", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(audio_features)}") + if not isinstance(audio_lens, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_lens. " + f"Got type: {type(audio_features)}") + if not isinstance(audio_token_len, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_token_len. " + f"Got type: {type(audio_features)}") + + return UltravoxAudioFeatureInputs(type="audio_features", + data=audio_features, + lens=audio_lens, + token_len=audio_token_len) + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + + return UltravoxAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, + audio_input: UltravoxAudioInputs, + ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]: + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + # Pad and concatenate audio features + # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] + audio_features = pad_and_concat_to_dim3(audio_input["data"]) + + # [B1, B2] -> [B1+B2] + audio_lens = flatten_bn(audio_input['lens'], concat=True) + audio_token_len = flatten_bn(audio_input['token_len'], concat=True) + + embeddings = self._audio_features_to_embeddings( + audio_features, audio_lens) + + # We should flatten and concatenate embeddings based on token lengths + # For example, with token_len = [4, 2, 3], flattened_embeddings will be + # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3]) + + # Create a mask of valid indices based on token lengths + max_len = embeddings.shape[1] + indices = torch.arange(max_len, device=embeddings.device).expand( + embeddings.shape[0], -1) + mask = indices < audio_token_len[:, None] + # Apply mask and flatten + flattened_embeddings = embeddings[mask] + + # Return one tensor per input audio + embed_lens = [ + token_len_item.sum().item() + for token_len_item in audio_input['token_len'] + ] + return flattened_embeddings.split(embed_lens) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + audio_embeddings = self._process_audio_input(audio_input) + return audio_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + + # TODO(ywang96): remove this block after v0 is deprecated. 
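+            # V0 path: scatter the audio embeddings using the placeholder
+            # index map from the attention metadata; V1 path: merge them at
+            # the positions that hold the reserved audio placeholder token.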
+ if not envs.VLLM_USE_V1: + attn_metadata = get_forward_context().attn_metadata + merge_multimodal_embeddings_from_map( + inputs_embeds, multimodal_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["audio"]) + else: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Ultravox + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted audio embeddings. The to-be-inserted + audio has a size that is essentially 6.25 tokens per second of audio. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + audio_features: A batch of audio input chunks [B, N, 80, M]. + audio_lens: Length of audio frames for each audio chunk [B]. + audio_token_len: Length of audio tokens for each audio chunk [B']. + Note: batch dim is different from batch dim in audio chunks. + + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + loader = AutoWeightsLoader(self, + ignore_unexpected_prefixes=["audio_tower."]) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + +def pad_and_concat_to_dim3( + features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]] +) -> torch.Tensor: + """ + Pad and concatenate a list of tensors. + + output: + Tensor of shape [B, C, M] where M is the maximum length of the input + tensors, B is the sum of the batch sizes of the input tensors. + C must be the same for all input tensors. 
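+
+    For example, inputs of shapes [B1, 80, M1] and [B2, 80, M2] are padded
+    along the last dimension and concatenated into [B1 + B2, 80, max(M1, M2)].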
+ """ + if isinstance(features, torch.Tensor): + if features.ndim > 3: + # Flatten [B, N, 80, M] -> [B * N, 80, M] + features = flatten_bn(features) + return features + + features = [pad_and_concat_to_dim3(f) for f in features] + + max_len = max(f.shape[-1] for f in features) + # Ensure all features have dim=3 + features = [f.view(-1, *f.shape[-2:]) for f in features] + # Pad and oncatenate: + # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] + features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features] + return torch.cat(features) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py new file mode 100644 index 0000000..f197434 --- /dev/null +++ b/vllm/model_executor/models/utils.py @@ -0,0 +1,705 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from dataclasses import dataclass, field +from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, Union, overload) + +import torch +import torch.nn as nn +from torch.func import functional_call +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors +from vllm.sequence import IntermediateTensors +from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available, + is_uva_available) + +logger = init_logger(__name__) + +WeightsMapping = Mapping[str, Optional[str]] +"""If a key maps to a value of `None`, the corresponding weight is ignored.""" + + +@dataclass +class WeightsMapper: + """Maps the name of each weight if they match the following patterns.""" + + orig_to_new_substr: WeightsMapping = field(default_factory=dict) + orig_to_new_prefix: WeightsMapping = field(default_factory=dict) + orig_to_new_suffix: WeightsMapping = field(default_factory=dict) + + def _map_name(self, key: str) -> Optional[str]: + for substr, new_key in self.orig_to_new_substr.items(): + if substr in key: + if new_key is None: + return None + + key = key.replace(substr, new_key, 1) + + for prefix, new_key in self.orig_to_new_prefix.items(): + if key.startswith(prefix): + if new_key is None: + return None + + key = key.replace(prefix, new_key, 1) + + for suffix, new_key in self.orig_to_new_suffix.items(): + if key.endswith(suffix): + if new_key is None: + return None + + key = new_key.join(key.rsplit(suffix, 1)) + + return key + + def apply( + self, weights: Iterable[Tuple[str, torch.Tensor]] + ) -> Iterable[Tuple[str, torch.Tensor]]: + return ((out_name, data) for name, data in weights + if (out_name := self._map_name(name)) is not None) + + +class AutoWeightsLoader: + """ + Helper class to load weights into a :class:`torch.nn.Module`. It is able + to automatically detect child modules and parameters while iterating over + the weights only once. + + The weight loading logic for individual modules can be overridden + by defining a ``load_weights`` method. + + Similarly, the weight loading logic for individual parameters can be + overridden by defining a ``weight_loader`` method. + + Detailed weight loading information can be viewed by setting the + environment variable ``VLLM_LOGGING_LEVEL=DEBUG``. 
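+
+    Example (mirroring how ``WhisperForConditionalGeneration.load_weights``
+    uses it):
+
+    .. code-block:: python
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)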
+ """ + + def __init__( + self, + module: nn.Module, + *, + skip_prefixes: Optional[List[str]] = None, + ignore_unexpected_prefixes: Optional[List[str]] = None, + ) -> None: + super().__init__() + + self.module = module + self.skip_prefixes = skip_prefixes or [] + self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or [] + + def _groupby_prefix( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]: + weights_by_parts = ((weight_name.split(".", 1), weight_data) + for weight_name, weight_data in weights) + + for prefix, group in itertools.groupby(weights_by_parts, + key=lambda x: x[0][0]): + yield ( + prefix, + # Because maxsplit=1 in weight_name.split(...), + # the length of `parts` must either be 1 or 2 + (("" if len(parts) == 1 else parts[1], weights_data) + for parts, weights_data in group), + ) + + def _get_qualname(self, prefix: str, rest: str) -> str: + if prefix == "": + return rest + if rest == "": + return prefix + + return ".".join((prefix, rest)) + + def _can_skip(self, qualname: str) -> bool: + return any(qualname.startswith(p) for p in self.skip_prefixes) + + def _can_ignore_unexpected(self, qualname: str) -> bool: + return any( + qualname.startswith(p) for p in self.ignore_unexpected_prefixes) + + def _load_param( + self, + base_prefix: str, + param: nn.Parameter, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> Iterable[str]: + for weight_name, weight_data in weights: + weight_qualname = self._get_qualname(base_prefix, weight_name) + + if self._can_skip(weight_qualname): + logger.debug("Skipping weight %s", weight_qualname) + + continue + + if weight_name != "": + if self._can_ignore_unexpected(weight_qualname): + logger.debug("Ignoring weight %s", weight_qualname) + + continue + + raise ValueError( + f"Attempted to load nested weight '{weight_qualname}' " + f"into a single parameter '{base_prefix}'") + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight_data) + + logger.debug("Loaded weight %s with shape %s", weight_qualname, + param.shape) + + yield weight_qualname + + def _add_loadable_non_param_tensors(self, module: nn.Module, + child_params: Dict[str, torch.Tensor]): + """ + Add tensor names that are not in the model params that may be in the + safetensors, e.g., batch normalization stats. 
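+
+        For example, a ``nn.BatchNorm1d`` child contributes its
+        ``running_mean``, ``running_var`` and ``num_batches_tracked`` buffers.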
+ """ + if isinstance(module, ( + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.LazyBatchNorm1d, + nn.LazyBatchNorm2d, + nn.LazyBatchNorm3d, + nn.SyncBatchNorm, + )): + module_state_dict = module.state_dict() + for stat_name in ("running_mean", "running_var", + "num_batches_tracked"): + child_params[stat_name] = module_state_dict[stat_name] + + def _load_module( + self, + base_prefix: str, + module: nn.Module, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> Iterable[str]: + if isinstance(module, PPMissingLayer): + return + + # Avoid infinite recursion since this function is typically + # called inside load_weights of the module itself + if module != self.module: + module_load_weights = getattr(module, "load_weights", None) + if callable(module_load_weights): + loaded_params = module_load_weights(weights) + if loaded_params is None: + logger.warning( + "Unable to collect loaded parameters " + "for module %s", module) + else: + yield from map( + lambda x: self._get_qualname(base_prefix, x), + loaded_params, + ) + + child_modules = dict(module.named_children()) + child_params = dict(module.named_parameters(recurse=False)) + + # Add missing tensors the weight loader needs to be able to load + # that aren't registered as params, e.g., batchnorm statistics. + self._add_loadable_non_param_tensors(module, child_params) + + for child_prefix, child_weights in self._groupby_prefix(weights): + prefix = self._get_qualname(base_prefix, child_prefix) + + if child_prefix in child_modules: + if self._can_skip(prefix + "."): + logger.debug("Skipping module %s", prefix) + + continue + + yield from self._load_module(prefix, + child_modules[child_prefix], + child_weights) + elif child_prefix in child_params: + if self._can_skip(prefix): + logger.debug("Skipping param %s", prefix) + + continue + + yield from self._load_param(prefix, child_params[child_prefix], + child_weights) + else: + can_skip_module = self._can_skip(prefix + ".") + can_skip_param = self._can_skip(prefix) + if can_skip_module or can_skip_param: + logger.debug("Skipping missing %s", prefix) + + continue + + can_ignore_module = self._can_ignore_unexpected(prefix + ".") + can_ignore_param = self._can_ignore_unexpected(prefix) + if can_ignore_module or can_ignore_param: + logger.debug("Ignoring missing %s", prefix) + + continue + + msg = (f"There is no module or parameter named '{prefix}' " + f"in {type(self.module).__name__}") + raise ValueError(msg) + + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + *, + mapper: Optional[WeightsMapper] = None, + ) -> Set[str]: + if mapper is not None: + weights = mapper.apply(weights) + + autoloaded_weights = set(self._load_module("", self.module, weights)) + return autoloaded_weights + + +def init_vllm_registered_model( + vllm_config: VllmConfig, + *, + prefix: str = "", + hf_config: Optional[PretrainedConfig] = None, + architectures: Optional[list[str]] = None, +) -> nn.Module: + """ + Helper function to initialize an inner model registered to vLLM, + based on the arguments passed to the outer vLLM model. 
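+
+    Multimodal wrappers typically use this to construct their text backbone,
+    e.g. passing the nested text config via ``hf_config`` and a
+    ``language_model`` prefix so the inner model's weights land under that
+    submodule name.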
+ """ + from vllm.model_executor.model_loader.loader import _initialize_model + + if hf_config is None and architectures is not None: + # So that the architectures field is overridden + hf_config = vllm_config.model_config.hf_config + + if hf_config is not None: + vllm_config = vllm_config.with_hf_config(hf_config, + architectures=architectures) + + return _initialize_model(vllm_config=vllm_config, prefix=prefix) + + +@overload +def flatten_bn(x: torch.Tensor) -> torch.Tensor: + ... + + +@overload +def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]: + ... + + +@overload +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: Literal[True], +) -> torch.Tensor: + ... + + +@overload +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + ... + + +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + """ + Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs. + + The input tensor should have shape ``(B, N, ...)```. + """ + if isinstance(x, torch.Tensor): + return x.flatten(0, 1) + + if concat: + return torch.cat(x) + + return [x_n for x_b in x for x_n in x_b] + + +def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: + """ + Recursively flattens and concatenates NestedTensors on all but the last + dimension. + """ + + if isinstance(embeddings, torch.Tensor): + # Flatten all but the last dimension. + return embeddings.flatten(0, -2) + + return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) + + +def _embedding_count_expression(embeddings: NestedTensors) -> str: + """ + Constructs a debugging representation of the number of embeddings in the + NestedTensors. + """ + + if isinstance(embeddings, torch.Tensor): + return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) + + return " + ".join( + _embedding_count_expression(inner) for inner in embeddings) + + +def merge_multimodal_embeddings_from_map( + inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, + placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided + placeholder map . + + Note: + This updates ``inputs_embeds`` in place. + """ + flattened_embeddings = _flatten_embeddings(multimodal_embeddings) + inputs_embeds[placeholder_map.dest] = flattened_embeddings[ + placeholder_map.src] + return inputs_embeds + + +def _merge_multimodal_embeddings( + inputs_embeds: torch.Tensor, + is_multimodal: torch.Tensor, + multimodal_embeddings: NestedTensors, +) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder tokens in + ``input_ids``. + + Note: + This updates ``inputs_embeds`` in place. 
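+
+    For example, with ``is_multimodal = [False, True, True, False]`` and two
+    flattened multimodal embeddings, rows 1 and 2 of ``inputs_embeds`` are
+    overwritten while rows 0 and 3 keep their original text embeddings.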
+ """ + num_expected_tokens = is_multimodal.sum().item() + assert isinstance(num_expected_tokens, int) + + flattened = _flatten_embeddings(multimodal_embeddings) + if flattened.shape[0] != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {flattened.shape[0]} " + f"multimodal tokens to {num_expected_tokens} placeholders") + + inputs_embeds[is_multimodal] = flattened + return inputs_embeds + + +def embed_multimodal( + input_ids: torch.Tensor, + multimodal_token_id: int, + get_text_embeds: Callable[[torch.Tensor], torch.Tensor], + multimodal_embeds: NestedTensors, +) -> torch.Tensor: + """ + Embed token IDs and multimodal inputs and combine their embeddings. + + ``multimodal_token_id`` is used to determine whether a token ID should + be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``. + + Compared to ``merge_multimodal_embeddings`, this avoids running + ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]`` + which causes issues when the placeholder token ID exceeds the + vocabulary size of the language model. + """ + is_multimodal = input_ids == multimodal_token_id + is_text = ~is_multimodal + + text_embeds = get_text_embeds(input_ids[is_text]) + merged_embeds = torch.empty( + (input_ids.shape[0], text_embeds.shape[1]), + dtype=text_embeds.dtype, + device=text_embeds.device, + ) + + merged_embeds[is_text] = text_embeds + + return _merge_multimodal_embeddings( + merged_embeds, + is_multimodal, + multimodal_embeds, + ) + + +def merge_multimodal_embeddings( + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: NestedTensors, + placeholder_token_id: Union[int, List[int]], +) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder tokens in + ``input_ids``. + + ``placeholder_token_id`` can be a list of token ids (e.g, token ids + of img_start, img_break, and img_end tokens) when needed: This means + the order of these tokens in the ``input_ids`` MUST MATCH the order of + their embeddings in ``multimodal_embeddings`` since we need to + slice-merge instead of individually scattering. + + For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where + - T is text token + - S is image start token + - I is image embedding token + - B is image break token + - E is image end token. + + Then the image embeddings (that correspond to I's) from vision encoder + must be padded with embeddings of S, B, and E in the same order of + input_ids for a correct embedding merge. + + Note: + This updates ``inputs_embeds`` in place. + """ + if isinstance(placeholder_token_id, list): + placeholder_token_id = torch.tensor(placeholder_token_id, + device=input_ids.device) + return _merge_multimodal_embeddings( + inputs_embeds, + torch.isin(input_ids, placeholder_token_id), + multimodal_embeddings, + ) + + return _merge_multimodal_embeddings( + inputs_embeds, + (input_ids == placeholder_token_id), + multimodal_embeddings, + ) + + +class LayerFn(Protocol): + + def __call__(self, prefix: str) -> torch.nn.Module: + ... + + +class PPMissingLayer(torch.nn.Identity): + """ + A placeholder layer for missing layers in a pipeline parallel model. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + self.return_tuple = kwargs.get("return_tuple", False) + + def forward(self, *args, **kwargs): + """ + Return the first arg from args or the first value from kwargs. 
+ + Wraps the input in a tuple if `self.return_tuple` is True. + """ + input = args[0] if args else next(iter(kwargs.values())) + return (input, ) if self.return_tuple else input + + +_CPU_OFFLOAD_BYTES = 0 +_CPU_OFFLOAD_MAX_BYTES = 0 + + +def set_cpu_offload_max_bytes(max_bytes: int) -> None: + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + _CPU_OFFLOAD_BYTES = 0 + _CPU_OFFLOAD_MAX_BYTES = max_bytes + + +def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: + if (params := next(module.parameters(), None)) is None: + return module + + device = params.device + + if device == torch.device("cpu"): + return module + + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + return module + + pin_memory = is_pin_memory_available() + uva_available = is_uva_available() + + if envs.VLLM_USE_V1: + assert uva_available, ("V1 CPU offloading requires" + " uva (pin memory) support") + uva_offloading = True + else: + uva_offloading = False + + # offload parameters to CPU + # use pin_memory if possible, which helps cudagraph capture speed + offloaded_parameters = False + for p in module.parameters(): + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + # we use per-parameter offloading + # one module might have some parameters offloaded and some not + break + + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device='cpu', + pin_memory=pin_memory) + cpu_data.copy_(p.data) + if not uva_offloading: + p.data = cpu_data + else: + # keep the cpu data alive + p._vllm_offloaded_cpu_data = cpu_data + p.data = get_cuda_view_from_cpu_tensor(cpu_data) + _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() + offloaded_parameters = True + + if offloaded_parameters and not uva_offloading: + original_forward = module.forward + + def forward(*args, **kwargs): + module.forward = original_forward + device_state = { + # here we blindly call `to(device)` + # if the parameter is already on the device, it will be a no-op + k: v.to(device, non_blocking=True) + for k, v in module.state_dict().items() + } + output = functional_call(module, + device_state, + args=args, + kwargs=kwargs) + module.forward = forward + return output + + module.forward = forward + + return module + + +def make_layers( + num_hidden_layers: int, + layer_fn: LayerFn, + prefix: str, +) -> Tuple[int, int, torch.nn.ModuleList]: + """Make a list of layers with the given layer function, taking + pipeline parallelism into account. 
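+
+    Layers outside this rank's ``[start_layer, end_layer)`` range are replaced
+    with :class:`PPMissingLayer` stubs. For example, with 32 hidden layers and
+    a pipeline-parallel world size of 4, rank 1 would (under the default even
+    partition) materialize layers 8-15 and return ``(8, 16, modules)``.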
+ """ + from vllm.distributed.parallel_state import get_pp_group + from vllm.distributed.utils import get_pp_indices + start_layer, end_layer = get_pp_indices(num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + modules = torch.nn.ModuleList( + [PPMissingLayer() for _ in range(start_layer)] + [ + maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + for idx in range(start_layer, end_layer) + ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + return start_layer, end_layer, modules + + +# NOTE: don't use lru_cache here because it can prevent garbage collection +_model_to_pp_missing_layer_names: Dict[int, List[str]] = {} + + +def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]: + """Get the names of the missing layers in a pipeline parallel model.""" + model_id = id(model) + if model_id in _model_to_pp_missing_layer_names: + return _model_to_pp_missing_layer_names[model_id] + + missing_layer_names = [] + for name, module in model.named_modules(): + if isinstance(module, PPMissingLayer): + # NOTE: the trailing dot is used to match the prefix of the layer. + # without the dot, we could match a layer that is not missing, + # e.g., 'encoder.layer.1' would match 'encoder.layer.11' + missing_layer_names.append(name + '.') + _model_to_pp_missing_layer_names[model_id] = missing_layer_names + + return missing_layer_names + + +def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: + """Check if a parameter is missing in a pipeline parallel model.""" + if isinstance(model, PPMissingLayer): + return True + + return any( + name.startswith(missing_layer_name) + for missing_layer_name in get_pp_missing_layer_names(model)) + + +def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): + + def make_empty_intermediate_tensors( + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: + return IntermediateTensors({ + key: + torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) + for key in keys + }) + + return make_empty_intermediate_tensors + + +def maybe_prefix(prefix: str, name: str) -> str: + """Add a prefix to a name if the prefix is non-empty. + + Args: + prefix: The prefix to add. If empty, no prefix will be added. + name: The name to potentially prefix. + + Returns: + The string "prefix.name" if prefix was non-empty, otherwise just "name". + """ + return name if not prefix else f"{prefix}.{name}" + + +def extract_layer_index(layer_name: str) -> int: + """ + Extract the layer index from the module name. 
+ Examples: + - "encoder.layers.0" -> 0 + - "encoder.layers.1.self_attn" -> 1 + - "2.self_attn" -> 2 + - "model.encoder.layers.0.sub.1" -> ValueError + """ + subnames = layer_name.split(".") + int_vals: List[int] = [] + for subname in subnames: + try: + int_vals.append(int(subname)) + except ValueError: + continue + assert len(int_vals) == 1, (f"layer name {layer_name} should" + " only contain one integer") + return int_vals[0] + + +def cast_overflow_tensors( + tensors: torch.Tensor, + offset: float = 1000, +) -> torch.Tensor: + if tensors.isinf().any() or tensors.isnan().any(): + clamp_value = torch.finfo(tensors.dtype).max - offset + tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) + return tensors diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py new file mode 100644 index 0000000..9e00da6 --- /dev/null +++ b/vllm/model_executor/models/vision.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast + +import torch +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.attention.selector import (backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.jsontree import JSONTree, json_map_leaves +from vllm.logger import init_logger +from vllm.platforms import _Backend, current_platform + +from .interfaces import MultiModalEmbeddings + +logger = init_logger(__name__) + +_C = TypeVar("_C", bound=PretrainedConfig) + + +class VisionEncoderInfo(ABC, Generic[_C]): + + def __init__(self, vision_config: _C) -> None: + super().__init__() + + self.vision_config = vision_config + + @abstractmethod + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + raise NotImplementedError + + @abstractmethod + def get_max_image_tokens(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_patch_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_patch_grid_length(self) -> int: + raise NotImplementedError + + +class VisionLanguageConfig(Protocol): + vision_config: Final[PretrainedConfig] + + +def get_vision_encoder_info( + hf_config: VisionLanguageConfig) -> VisionEncoderInfo: + # Avoid circular imports + from .clip import CLIPEncoderInfo, CLIPVisionConfig + from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig + from .siglip import SiglipEncoderInfo, SiglipVisionConfig + + vision_config = hf_config.vision_config + if isinstance(vision_config, CLIPVisionConfig): + return CLIPEncoderInfo(vision_config) + if isinstance(vision_config, PixtralVisionConfig): + # Need to sneak in spatial_merge_size for Mistral3 + vision_config.spatial_merge_size = getattr(hf_config, + "spatial_merge_size", 1) + return PixtralHFEncoderInfo(vision_config) + if isinstance(vision_config, SiglipVisionConfig): + return SiglipEncoderInfo(vision_config) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def get_vit_attn_backend(support_fa: bool = False) -> _Backend: + """ + Get the available attention backend for Vision Transformer. + """ + # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn. 
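+    # Resolution order below: a globally forced backend wins, then the
+    # VLLM_ATTENTION_BACKEND environment variable, then a platform default
+    # (FLASH_ATTN or XFORMERS on CUDA, TORCH_SDPA elsewhere).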
+ selected_backend: Optional[_Backend] = get_global_forced_attn_backend() + if selected_backend is None: + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + if current_platform.is_cuda(): + device_available = current_platform.has_device_capability(80) + if device_available and support_fa: + from transformers.utils import is_flash_attn_2_available + if is_flash_attn_2_available(): + selected_backend = _Backend.FLASH_ATTN + else: + logger.warning_once( + "Current `vllm-flash-attn` has a bug inside vision " + "module, so we use xformers backend instead. You can " + "run `pip install flash-attn` to use flash-attention " + "backend.") + selected_backend = _Backend.XFORMERS + else: + # For Volta and Turing GPUs, use xformers instead. + selected_backend = _Backend.XFORMERS + else: + # Default to torch SDPA for other non-GPU platforms. + selected_backend = _Backend.TORCH_SDPA + return selected_backend + + +def resolve_visual_encoder_outputs( + encoder_outputs: Union[torch.Tensor, list[torch.Tensor]], + feature_sample_layers: Optional[list[int]], + post_layer_norm: Optional[torch.nn.LayerNorm], + max_possible_layers: int, +) -> torch.Tensor: + """Given the outputs a visual encoder module that may correspond to the + output of the last layer, or a list of hidden states to be stacked, + handle post normalization and resolve it into a single output tensor. + + Args: + encoder_outputs: Output of encoder's last layer or all hidden states. + feature_sample_layers: Optional layer indices to grab from the encoder + outputs; if provided, encoder outputs must be a list. + post_layer_norm: Post norm to apply to the output of the encoder. + max_possible_layers: Total layers in the fully loaded visual encoder. + + """ + if feature_sample_layers is None: + if post_layer_norm is not None: + return post_layer_norm(encoder_outputs) + return encoder_outputs + + # Get the hidden states corresponding to the layer indices. + # Negative values are relative to the full visual encoder, + # so offset them depending on how many layers were loaded. + # NOTE: this assumes that encoder_outputs is a list containing + # the inputs to the visual encoder, followed by the hidden states + # of each layer. + num_loaded_layers = len(encoder_outputs) - 1 + offset = max_possible_layers - num_loaded_layers + hs_pool = [ + encoder_outputs[layer_idx] + if layer_idx >= 0 else encoder_outputs[layer_idx + offset] + for layer_idx in feature_sample_layers + ] + + # Apply post-norm on the final hidden state if we are using it + uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) + if post_layer_norm is not None and uses_last_layer: + hs_pool[-1] = post_layer_norm(encoder_outputs) + return torch.cat(hs_pool, dim=-1) + + +def scatter_patch_features( + patches: Union[torch.Tensor, Sequence[torch.Tensor]], + embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]], +) -> tuple[torch.Tensor, ...]: + """ + Scatter the patch features into a contiguous tensor that corresponds + to the embedding tokens defined by the multimodal processor. + + The rest of the values in the tensor are set to NaN so that they + can be filtered out by :func`select_patch_features`. + + Args: + patches: The patch features for each image. + Shape: `(num_images, , feature_depth)` + embed_is_patch: A boolean mask indicating which image embeddings + correspond to patch tokens for each image. 
+ Shape: `(num_images, num_embeds)` + + Note: + The original code only considers patch tokens as feature + tokens, but our processor considers all image-related tokens + as feature tokens because the feature tokens need to be + consecutive in `input_ids`. + + Example: + A simplified example for one image: + + .. code-block:: + + Embedding tokens (from HF processor): + [ ] + + embed_is_patch (from HF processor): + [ False True True False True True False False ] + + Encoder outputs (from model): + [ p1 p2 p3 p4 ] + + The resulting embedding tensor is: + [ nan p1 p2 nan p3 p4 nan nan ] + """ + if len(patches) != len(embed_is_patch): + raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. " + f"{len(embed_is_patch)=}") + + def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor): + embed_one = patches_one.new_full( + (e_is_patch.shape[0], patches_one.shape[-1]), + fill_value=torch.nan, + ) + embed_one[e_is_patch] = patches_one + return embed_one + + return tuple( + get_embed_one(patches_one, e_is_patch) + for patches_one, e_is_patch in zip(patches, embed_is_patch)) + + +def select_patch_features( + multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings: + """ + Given the outputs of :func:`scatter_patch_features`, return only + the values that correspond to patch features. + """ + selected_features = json_map_leaves( + lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), + cast(JSONTree[torch.Tensor], multimodal_embeddings), + ) + return cast(MultiModalEmbeddings, selected_features) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py new file mode 100644 index 0000000..eb64049 --- /dev/null +++ b/vllm/model_executor/models/whisper.py @@ -0,0 +1,756 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, + WhisperProcessor) +from transformers.models.whisper.modeling_whisper import sinusoids + +from vllm.attention import Attention, AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs + +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, + SupportsTranscription, SupportsV0Only) +from .utils import 
(AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, + make_layers) + +logger = init_logger(__name__) + + +class WhisperAudioInputs(TypedDict): + input_features: NestedTensors + """Shape: `(batch_size, 128, M)`""" + + +class WhisperPositionalEmbedding(nn.Embedding): + + def __init__(self, num_positions: int, embedding_dim: int): + super().__init__(num_positions, embedding_dim) + + def forward(self, position_ids): + return self.weight[position_ids] + + +class WhisperAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + attn_type: AttentionType = AttentionType.DECODER, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = embed_dim + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = self.embed_dim // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.attn_type = attn_type + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) + self.out_proj = RowParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) + + def _init_qkv( + self, + embed_dim: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + ): + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + + return output + + +class WhisperCrossAttention(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + attn_type=AttentionType.ENCODER_DECODER, + ) + + def _init_qkv( + self, + embed_dim: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = 
"", + ) -> None: + self.q_proj = ColumnParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=0, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + ): + q, _ = self.q_proj(hidden_states) + + # Encoder hidden states are only computed once during prefill phase. + # Afterwards, the keys and values should be available in the kv-cache. + if encoder_hidden_states is not None: + kv, _ = self.kv_proj(encoder_hidden_states) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + else: + k = v = None + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + + return output + + +class WhisperMLP(nn.Module): + + def __init__( + self, + embed_dim: int, + ffn_dim: int, + act_fn: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.activation_fn = get_act_fn(act_fn) + self.fc1 = ColumnParallelLinear( + input_size=embed_dim, + output_size=ffn_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + input_size=ffn_dim, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class WhisperEncoderLayer(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.embed_dim = config.d_model + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + attn_type=AttentionType.ENCODER, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.mlp = WhisperMLP( + embed_dim=config.d_model, + ffn_dim=config.encoder_ffn_dim, + act_fn=config.activation_function, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = cast_overflow_tensors(hidden_states) + + return hidden_states + + +class WhisperDecoderLayer(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.self_attn = WhisperAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + attn_type=AttentionType.DECODER, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + 
self.self_attn_layer_norm = nn.LayerNorm(config.d_model) + self.encoder_attn = WhisperCrossAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder_attn", + ) + self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model) + self.mlp = WhisperMLP( + embed_dim=config.d_model, + ffn_dim=config.decoder_ffn_dim, + act_fn=config.activation_function, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.final_layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperEncoder(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = (math.sqrt(embed_dim) + if config.scale_embedding else 1.0) + + self.conv1 = nn.Conv1d(self.num_mel_bins, + embed_dim, + kernel_size=3, + padding=1) + self.conv2 = nn.Conv1d(embed_dim, + embed_dim, + kernel_size=3, + stride=2, + padding=1) + self.embed_positions = nn.Embedding(self.max_source_positions, + embed_dim) + self.start_layer, self.end_layer, self.layers = make_layers( + config.encoder_layers, + lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, + prefix=f"{prefix}.layers"), + prefix=f"{prefix}.layers", + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + with torch.no_grad(): + self.embed_positions.weight.copy_( + sinusoids(*self.embed_positions.weight.shape)) + + def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]): + hidden_states = [] + for features in input_features: + embeds = nn.functional.gelu(self.conv1(features)) + embeds = nn.functional.gelu(self.conv2(embeds)) + embeds = embeds.permute(1, 0) + embeds = embeds + self.embed_positions.weight[:embeds.size(0), :] + hidden_states.append(embeds) + hidden_states = torch.cat(hidden_states) + + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class WhisperDecoder(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.max_source_positions = config.max_source_positions + self.embed_scale = (math.sqrt(config.d_model) + if config.scale_embedding else 1.0) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, + self.padding_idx) + self.embed_positions = WhisperPositionalEmbedding( + self.max_target_positions, 
config.d_model) + self.start_layer, self.end_layer, self.layers = make_layers( + config.decoder_layers, + lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, + prefix=f"{prefix}.layers"), + prefix=f"{prefix}.layers", + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_ids, + positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + ): + inputs_embeds = self.get_input_embeddings(input_ids) + positions = self.embed_positions(positions) + hidden_states = inputs_embeds + positions + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + + +class WhisperModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = WhisperEncoder(vllm_config=vllm_config, + prefix=f"{prefix}.encoder") + self.decoder = WhisperDecoder(vllm_config=vllm_config, + prefix=f"{prefix}.decoder") + + def forward( + self, + input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + ) -> torch.Tensor: + encoder_outputs = self.get_encoder_outputs(input_features) + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + encoder_hidden_states=encoder_outputs, + ) + return decoder_outputs + + def get_encoder_outputs( + self, + input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], + ) -> Optional[torch.Tensor]: + if input_features is None: + return None + return self.encoder(input_features) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"), + (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class WhisperProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> WhisperConfig: + return self.ctx.get_hf_config(WhisperConfig) + + def get_hf_processor(self, + sampling_rate: Optional[int] = None + ) -> WhisperProcessor: + return self.ctx.get_hf_processor(WhisperProcessor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": 1} + + def get_feature_extractor(self) -> WhisperFeatureExtractor: + hf_processor = self.get_hf_processor() + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_max_audio_tokens(self) -> int: + return self.get_hf_config().max_source_positions + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"audio": self.get_max_audio_tokens()} + + +class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text="<|startoftranscript|>" * num_audios, + mm_data=mm_data, + ) + + +class WhisperMultiModalProcessor( + EncDecMultiModalProcessor[WhisperProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + # Strictly speaking, whisper encoder only accept audio features. + # We create a dummy encoder prompt here which will be padded to + # num_audio_tokens. So that we can create dummy data from this + # for encoder profiling. 
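+        # The single placeholder token returned below is later expanded by
+        # the PromptReplacement in _get_prompt_updates to
+        # self.info.get_max_audio_tokens() tokens (i.e. max_source_positions).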
+ return [0] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(input_features=MultiModalFieldConfig.batched("audio")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + num_tokens = self.info.get_max_audio_tokens() + return [ + PromptReplacement( + modality="audio", + target=[0], + replacement=[0] * num_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor, + info=WhisperProcessingInfo, + dummy_inputs=WhisperDummyInputsBuilder) +class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, + SupportsMultiModal, SupportsV0Only): + packed_modules_mapping = { + "self_attn.qkv_proj": [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ], + "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], + } + + hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ + ".fc1.": ".mlp.fc1.", + ".fc2.": ".mlp.fc2." + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) + self.unpadded_vocab_size = config.vocab_size + self.proj_out = ParallelLMHead(config.vocab_size, + config.d_model, + quant_config=quant_config) + self.proj_out = self.proj_out.tie_weights( + self.model.decoder.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + audio_input = self._parse_and_validate_audio_input(**kwargs) + decoder_outputs = self.model( + input_features=audio_input["input_features"], + input_ids=input_ids, + positions=positions, + ) + return decoder_outputs + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + # TODO: This method does not obey the interface for SupportsMultiModal. + # Refactor this once encoder/decoder support is implemented in V1. + audio_input = self._parse_and_validate_audio_input(**kwargs) + return self.model.get_encoder_outputs(audio_input["input_features"]) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + # TODO: This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. Refactor this once + # encoder/decoder support is implemented in V1. 
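+        # `multimodal_embeddings` is intentionally unused here: audio features
+        # are consumed by the encoder via get_multimodal_embeddings() rather
+        # than merged into the decoder token embeddings.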
+ return self.model.decoder.get_input_embeddings(input_ids) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> WhisperAudioInputs: + input_features = kwargs.pop("input_features", None) + + if input_features is not None: + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(input_features)}") + input_features = torch.cat( + [feat.to(self.dtype) for feat in input_features]) + + return WhisperAudioInputs(input_features=input_features) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.proj_out, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) + + # add fake zeros bias for k_proj to state_dict + weights = _create_fake_bias_for_k_proj(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + +def _create_fake_bias_for_k_proj( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Iterable[Tuple[str, torch.Tensor]]: + """ + Create full zeros bias for k_proj weight in self-attn and x-attn layers. + So that the bias for k_proj in qkv_proj can be initialized with zeros. + """ + for name, weight in weights: + if name.endswith(".k_proj.weight"): + bias = torch.zeros(weight.size(0)) + bias_name = name.replace("weight", "bias") + yield from [(name, weight), (bias_name, bias)] + yield name, weight diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py new file mode 100644 index 0000000..7e21024 --- /dev/null +++ b/vllm/model_executor/models/zamba2.py @@ -0,0 +1,1031 @@ +# SPDX-License-Identifier: Apache-2.0 +"""PyTorch Zamba2 model implementation for vLLM. + +This module implements the Zamba2 architecture from +https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer +architectures in a hybrid model optimized for efficient sequence modeling. The +model alternates between state space model layers and attention-based layers. 
+""" +from itertools import cycle +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Zamba2Config + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import HasInnerState, IsHybrid, SupportsV0Only +from .utils import maybe_prefix + + +class Zamba2LoRA(nn.Module): + """LoRA layer for the Zamba2 model. + + Implements a LoRA layer that is used in shared attention and gated MLP + blocks. + """ + + def __init__( + self, + input_dim: int, + rank: int, + output_dim: Union[int, List[int]], + quant_config: Optional[QuantizationConfig] = None, + ): + """Initialize the attention layer. + + Args: + input_dim: input dimension + rank: LoRA rank + output_dim: output dimension + quant_config: Configuration for model quantization + """ + super().__init__() + + self.A = ColumnParallelLinear(input_dim, + rank, + bias=False, + quant_config=quant_config, + gather_output=True) + + if isinstance(output_dim, list): + B_class = MergedColumnParallelLinear + else: + B_class = ColumnParallelLinear + self.B = B_class(rank, + output_dim, + bias=False, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + ): + lora_output, _ = self.A(hidden_states) + lora_output, _ = self.B(lora_output) + return lora_output + + +class Zamba2Attention(nn.Module): + """Multi-head attention mechanism for the Zamba2 model. + + Implements attention with parallel computation, QKV projections, optional + adapters and rotary position embeddings. The attention is computed across + distributed blocks for efficient processing. + """ + + def __init__( + self, + config: Zamba2Config, + bare_block_idx: int, + num_hybrid_layers: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + """Initialize the attention layer. 
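+
+        Attention weights are shared across the hybrid layers this block
+        serves; only the per-layer KV caches (``dpa_list``) and, when enabled,
+        the per-layer Q/K/V LoRA adapters are instantiated per hybrid layer.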
+ + Args: + config: The Zamba2 model configuration + bare_block_idx: Index of the bare attention block + num_hybrid_layers: Total number of hybrid layers + cache_config: Configuration for key-value caching + quant_config: Configuration for model quantization + prefix: Optional prefix for parameter names + """ + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.config = config + self.num_hybrid_layers = num_hybrid_layers + self.rope_theta = config.rope_theta + + self.attention_hidden_size = config.attention_hidden_size + self.total_num_attention_heads = config.num_attention_heads + assert self.total_num_attention_heads % tp_size == 0 + self.num_attention_heads = config.num_attention_heads // tp_size + self.attention_head_dim = config.attention_head_dim + self.qkv_size = self.attention_hidden_size // tp_size + self.scale = (self.attention_head_dim / 2)**-0.5 + + if (self.attention_head_dim * + self.total_num_attention_heads) != self.attention_hidden_size: + raise ValueError( + f"attention_hidden_size must be divisible by" + f" num_attention_heads" + f" (got `attention_hidden_size`: {self.attention_hidden_size}" + f" and `num_heads`: {self.num_attention_heads}).") + + self.qkv_proj = QKVParallelLinear( + self.attention_hidden_size, + self.attention_head_dim, + self.total_num_attention_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.attention_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config) + + # Even though in Zamba2 weights are shared between attention layers, KV + # cache is unique for every attention layer. Hence, we need to define + # separate Attention objects, because in recent vLLM KV cache tensors + # are tied to specific Attention objects. + + # Initialize attention blocks with proper indexing + self.dpa_list = nn.ModuleList([]) + j = bare_block_idx * (self.num_hybrid_layers + config.num_mem_blocks - + 1) // config.num_mem_blocks + for block_idx in range(self.num_hybrid_layers): + if block_idx % config.num_mem_blocks == bare_block_idx: + dpa = Attention( + self.num_attention_heads, + self.attention_head_dim, + self.scale, + cache_config=cache_config, + prefix=f"{prefix}.attn.{j}", + ) + j += 1 + else: + dpa = nn.Identity() + self.dpa_list.append(dpa) + + # Initialize adapter layers if enabled + if config.use_shared_attention_adapter: + self.linear_q_adapter_list = nn.ModuleList([]) + self.linear_k_adapter_list = nn.ModuleList([]) + self.linear_v_adapter_list = nn.ModuleList([]) + + for block_idx in range(self.num_hybrid_layers): + if block_idx % config.num_mem_blocks == bare_block_idx: + linear_q_adapter = Zamba2LoRA( + self.attention_hidden_size, + config.adapter_rank, + self.attention_hidden_size, + quant_config=quant_config, + ) + linear_k_adapter = Zamba2LoRA( + self.attention_hidden_size, + config.adapter_rank, + self.attention_hidden_size, + quant_config=quant_config, + ) + linear_v_adapter = Zamba2LoRA( + self.attention_hidden_size, + config.adapter_rank, + self.attention_hidden_size, + quant_config=quant_config, + ) + else: + linear_q_adapter = nn.Identity() + linear_k_adapter = nn.Identity() + linear_v_adapter = nn.Identity() + + self.linear_q_adapter_list.append(linear_q_adapter) + self.linear_k_adapter_list.append(linear_k_adapter) + self.linear_v_adapter_list.append(linear_v_adapter) + + if config.use_mem_rope: + self.rotary_emb = get_rope( + head_size=self.attention_head_dim, + rotary_dim=self.attention_head_dim, + max_position=config.max_position_embeddings, + 
                base=self.rope_theta,
+                rope_scaling=None,
+                is_neox_style=True,
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        block_idx: int,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward pass through the attention layer.
+
+        Args:
+            hidden_states: Input tensor [batch_size, seq_len, hidden_size]
+            position_ids: Position IDs for positional embeddings
+            block_idx: Current shared transformer block index
+
+        Returns:
+            Output tensor [batch_size, seq_len, hidden_size]
+        """
+        qkv, _ = self.qkv_proj(hidden_states)
+        query_states, key_states, value_states = qkv.split([self.qkv_size] * 3,
+                                                           dim=-1)
+
+        if self.config.use_shared_attention_adapter:
+            # Apply adapter transformations to Q, K, V if enabled
+            q_adapter = self.linear_q_adapter_list[block_idx]
+            assert not isinstance(q_adapter, nn.Identity)
+            q_lora_output = q_adapter(hidden_states)
+            query_states = query_states + q_lora_output
+
+            k_adapter = self.linear_k_adapter_list[block_idx]
+            assert not isinstance(k_adapter, nn.Identity)
+            k_lora_output = k_adapter(hidden_states)
+            key_states = key_states + k_lora_output
+
+            v_adapter = self.linear_v_adapter_list[block_idx]
+            assert not isinstance(v_adapter, nn.Identity)
+            v_lora_output = v_adapter(hidden_states)
+            value_states = value_states + v_lora_output
+
+        if self.config.use_mem_rope:
+            query_states, key_states = self.rotary_emb(position_ids,
+                                                       query_states,
+                                                       key_states)
+
+        y = self.dpa_list[block_idx](query_states, key_states, value_states)
+        y, _ = self.o_proj(y)
+        return y
+
+
+class Zamba2MLP(nn.Module):
+    """Feed-forward MLP layer for the Zamba2 model.
+
+    Implements a gated feed-forward network that projects inputs to a larger
+    intermediate size, applies GELU activation with gating, then projects back
+    to the original size. Includes optional adapter layers for model adaptation.
+    """
+
+    def __init__(
+        self,
+        config: Zamba2Config,
+        bare_block_idx: int,
+        num_hybrid_layers: int,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        """Initialize the MLP layer.
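+
+        Note: as with the attention weights, this MLP is shared across the
+        hybrid layers served by `bare_block_idx`; a per-layer LoRA adapter on
+        the gate/up projection (selected by `block_idx` in `forward`) makes
+        each shared invocation layer-specific.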
+ + Args: + config: The Zamba2 model configuration + bare_block_idx: Index of the bare block in the model + num_hybrid_layers: Total number of hybrid layers + quant_config: Configuration for model quantization + """ + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.num_hybrid_layers = num_hybrid_layers + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Main projection layers with gating + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + 2 * [self.intermediate_size], # 2x for gate and input projections + bias=self.config.add_bias_linear, + quant_config=quant_config) + + self.down_proj = RowParallelLinear(self.intermediate_size, + self.hidden_size, + bias=self.config.add_bias_linear, + quant_config=quant_config) + + # Only allow GELU activations + if config.hidden_act != "gelu": + raise ValueError(f"Only GELU activation is supported " + f"(got `hidden_act`: {config.hidden_act})") + self.act_fn = GeluAndMul() + + # Initialize adapter layers + self.gate_up_proj_adapter_list = nn.ModuleList([]) + for block_idx in range(self.num_hybrid_layers): + if block_idx % config.num_mem_blocks == bare_block_idx: + gate_up_proj_adapter = Zamba2LoRA( + config.hidden_size, + config.adapter_rank, + 2 * [self.intermediate_size], + quant_config, + ) + else: + gate_up_proj_adapter = nn.Identity() + self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) + + def forward(self, hidden_states: torch.Tensor, + block_idx: int) -> torch.Tensor: + """Forward pass through the MLP layer. + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + block_idx: Current shared transformer block index + + Returns: + Output tensor [batch_size, seq_len, hidden_size] after applying + gated feed-forward transformation + """ + # Project input to intermediate size with gating + gate_up_states, _ = self.gate_up_proj(hidden_states) + + # Apply adapter transformation if present + adapter = self.gate_up_proj_adapter_list[block_idx] + assert not isinstance(adapter, nn.Identity) + lora_output = adapter(hidden_states) + gate_up_states = gate_up_states + lora_output + + # Apply GELU activation with gating + hidden_states = self.act_fn(gate_up_states) + + # Project back to hidden size + output, _ = self.down_proj(hidden_states) + return output + + +class Zamba2AttentionDecoderLayer(nn.Module): + """Single decoder layer combining attention and feed-forward networks. + + This layer implements a standard transformer block with: + - Input layer normalization + - Multi-head self-attention + - Pre-feed-forward layer normalization + - Feed-forward network (MLP) + """ + + def __init__( + self, + config: Zamba2Config, + bare_block_idx: int, + num_hybrid_layers: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + """Initialize the decoder layer. 
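+
+        Note: `input_layernorm` is sized 2 * hidden_size because it is applied
+        to the concatenation of the current hidden states with the original
+        block input (see `forward`).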
+ + Args: + config: The Zamba2 model configuration + bare_block_idx: Index of the bare block + num_hybrid_layers: Total number of hybrid layers + cache_config: Configuration for key-value caching + quant_config: Configuration for model quantization + prefix: Optional prefix for parameter names + """ + super().__init__() + + # Initialize attention sublayer + self.self_attn = Zamba2Attention( + config, + bare_block_idx=bare_block_idx, + num_hybrid_layers=num_hybrid_layers, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + # Initialize feed-forward sublayer + self.feed_forward = Zamba2MLP( + config, + bare_block_idx=bare_block_idx, + num_hybrid_layers=num_hybrid_layers, + quant_config=quant_config, + ) + + # Initialize layer normalizations + # Input normalization operates on concatenated states + self.input_layernorm = RMSNorm(2 * config.hidden_size, + eps=config.rms_norm_eps) + # Pre-FF normalization operates on attention output + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + block_idx: int, + positions: torch.Tensor, + ) -> torch.Tensor: + """Forward pass through the decoder layer. + + Args: + hidden_states: Input tensor from previous layer + original_hidden_states: Original input tensor for residual + connection + block_idx: Current shared transformer block index + positions: IDs for positional embeddings + + Returns: + Transformed hidden states after attention and feed-forward + """ + + # The argument original_hidden_states is concatenated with hidden_states + # (which is the output of the previous (mamba) layer). + # The concatenated tensor is then used as input of the pre-attention + # RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712). + hidden_states = torch.concatenate( + [hidden_states, original_hidden_states], dim=-1) + + # Layer norm before attention + hidden_states = self.input_layernorm(hidden_states) + + # Self attention + hidden_states = self.self_attn( + hidden_states, + position_ids=positions, + block_idx=block_idx, + ) + + # Layer norm before feed-forward + hidden_states = self.pre_ff_layernorm(hidden_states) + + # Feed-forward network + hidden_states = self.feed_forward(hidden_states, block_idx=block_idx) + + return hidden_states + + +class Zamba2MambaDecoderLayer(nn.Module): + """Single Mamba decoder layer with normalization. + + This implements a Mamba block. It includes input normalization + and can process sequences using either chunked or full + computation depending on configuration. + """ + + def __init__( + self, + config: Zamba2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + """Initialize the Mamba decoder layer. 
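+
+        Note: when used inside a hybrid layer, the projected output of the
+        shared transformer block is added to this layer's input before
+        normalization and the Mamba mixer (eq. (6) of
+        https://arxiv.org/pdf/2405.16712).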
+ + Args: + config: The Zamba2 model configuration + quant_config: Configuration for model quantization + """ + super().__init__() + + # Initialize Mamba mixer with expanded intermediate size + intermediate_size = config.mamba_expand * config.hidden_size + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=intermediate_size, + use_conv_bias=config.use_conv_bias, + use_bias=config.add_bias_linear, + n_groups=config.mamba_ngroups, + num_heads=config.n_mamba_heads, + head_dim=intermediate_size // config.n_mamba_heads, + rms_norm_eps=config.rms_norm_eps, + activation="silu", + chunk_size=config.chunk_size, + quant_config=quant_config, + ) + + # Input normalization + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + transformer_hidden_states: Optional[torch.Tensor] = None, + positions: Optional[torch.Tensor] = None, + original_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass through the Mamba decoder layer. + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + mamba_cache_params: Parameters for Mamba's state caches + (one for conv, one for ssm) + sequence_idx: Index tensor for identifying sequences in batch + Required for proper chunked processing in prefill + transformer_hidden_states: Optional output from transformer path + Added to input if provided (used in hybrid architecture) + positions: Optional position IDs (unused in Mamba) + original_hidden_states: Optional original inputs (unused in Mamba) + + Returns: + Transformed hidden states with residual connection applied + """ + # Store input for residual connection + residual = hidden_states + + # `transformer_hidden_states` is the output from shared + # transformer + linear layer (see fig. 2 in + # https://arxiv.org/pdf/2405.16712). + # `transformer_hidden_states` is then added to the input to the mamba + # layer below (as described in eq. (6) of + # https://arxiv.org/pdf/2405.16712). + if transformer_hidden_states is not None: + hidden_states = hidden_states + transformer_hidden_states + + # Apply input normalization + hidden_states = self.input_layernorm(hidden_states) + + # Process through Mamba mixer + hidden_states = self.mamba( + hidden_states, + mamba_cache_params=mamba_cache_params, + sequence_idx=sequence_idx, + ) + + # residual connection after mamba + hidden_states = residual + hidden_states + + return hidden_states + + +class Zamba2HybridLayer(nn.Module): + """Hybrid layer combining Transformer and Mamba architectures. + + This layer implements the hybrid architecture described in the Zamba paper, + where a shared transformer pathway processes input in parallel with a Mamba + pathway. The transformer output is projected and added to the Mamba input + for enhanced representation learning. + """ + + def __init__( + self, + shared_transformer: Zamba2AttentionDecoderLayer, + config: Zamba2Config, + block_idx: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + """Initialize the hybrid layer. 
+ + Args: + shared_transformer: Transformer decoder layer for attention pathway + linear: Linear projection for transformer output before Mamba + mamba: Mamba decoder layer for state space pathway + """ + super().__init__() + self.block_idx = block_idx + self.shared_transformer = shared_transformer + self.linear = ReplicatedLinear(config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config) + self.mamba_decoder = Zamba2MambaDecoderLayer(config, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None, + sequence_idx: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass through the hybrid layer. + + Processes input through parallel transformer and Mamba paths: + 1. Transformer path processes input with attention + 2. Transformer output is projected to match hidden size + 3. Projected output is added to Mamba path input + 4. Final output combines both paths' representations + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + original_hidden_states: Original input for transformer residual + connection + positions: Position IDs for positional embeddings + mamba_cache_params: Parameters for Mamba's state caches + (one for conv, one for ssm) + sequence_idx: Indices for identifying sequences in batch, + required for proper chunked processing in prefill + + Returns: + Output tensor combining transformer and Mamba representations + """ + # Process through transformer pathway + transformer_hidden_states = self.shared_transformer( + hidden_states, + original_hidden_states=original_hidden_states, + block_idx=self.block_idx, + positions=positions, + ) + + # Project transformer output + transformer_hidden_states, _ = self.linear(transformer_hidden_states) + + # Process through Mamba pathway with transformer injection + layer_outputs = self.mamba_decoder( + hidden_states, + transformer_hidden_states=transformer_hidden_states, + mamba_cache_params=mamba_cache_params, + sequence_idx=sequence_idx, + ) + + return layer_outputs + + +class Zamba2Model(nn.Module): + """Core Zamba2 model combining transformer and Mamba architectures. + + The model processes input through a sequence of hybrid and Mamba-only + layers, using token embeddings and final layer normalization. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + """Initialize the Zamba2 model. 
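+
+        Note: layers are built according to `config.layers_block_type`:
+        "hybrid" entries pair one of the `config.num_mem_blocks` shared
+        transformer blocks (reused cyclically) with a Mamba layer, while all
+        other entries are plain Mamba decoder layers.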
+ + Args: + vllm_config: Configuration object containing model, cache, + quantization and LoRA settings + prefix: Optional prefix for parameter names in state dict + """ + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) + assert not is_lora_enabled + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + # Initialize token embeddings + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + # Map hybrid layer indices to block indices + layer2block_map = { + layer_idx: block_idx + for block_idx, layer_idx in enumerate(config.hybrid_layer_ids) + } + + # Create cyclic iterator of transformer blocks + blocks = cycle([ + Zamba2AttentionDecoderLayer(config, + bare_block_idx=idx, + num_hybrid_layers=len(layer2block_map), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}") + for idx in range(config.num_mem_blocks) + ]) + + # Initialize layers according to block type configuration + layers = [] + for layer_idx, layer_type in enumerate(config.layers_block_type): + if layer_type == "hybrid": + block = next(blocks) + block_idx = layer2block_map[layer_idx] + layers.append( + Zamba2HybridLayer(block, config, block_idx, quant_config)) + else: + layers.append( + Zamba2MambaDecoderLayer(config, quant_config=quant_config)) + self.layers = nn.ModuleList(layers) + + # Final layer normalization + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings. + + Args: + input_ids: Tensor of input token IDs + + Returns: + Embedded representation of the input tokens + """ + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Forward pass through the model. 
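+
+        Note: during prefill a per-token sequence index is derived from
+        `query_start_loc` so the Mamba mixer can tell sequences apart when
+        processing a chunked/continuously-batched input.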
+ + Args: + input_ids: Input token IDs + positions: Position IDs for embeddings + mamba_cache_params: Parameters for Mamba's state caches + (one for conv, one for ssm) + inputs_embeds: Optional pre-computed input embeddings + + Returns: + Either final hidden states or intermediate tensors for pipeline + parallelism + """ + # Handle pipeline parallelism for first rank + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + # Process through layers + original_hidden_states = torch.clone(hidden_states) + for layer_idx, layer in enumerate(self.layers): + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + positions=positions, + mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx), + sequence_idx=seq_idx, + ) + hidden_states = layer_outputs + + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): + """Zamba2 model with causal language modeling head. + + This class wraps the core Zamba2 model and adds: + - A language modeling head for next token prediction + - Mamba state caching functionality + - Support for model parallelism and quantization + - Sampling capabilities for text generation + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + """Initialize the Zamba2 model for causal language modeling. + + Args: + vllm_config: Configuration containing model, cache, quantization, + LoRA and scheduler settings + prefix: Optional prefix for parameter names + + Raises: + AssertionError: If prefix caching is enabled (not supported by + Mamba) + """ + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + # Initialize core model + self.model = Zamba2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + # Initialize language modeling head + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Tie weights with input embeddings if using same dimensions + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + # Used to track and store by the Mamba cache between steps. 
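+        # The cache manager is created lazily on the first forward() call,
+        # once the dtype and per-layer state shapes are known.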
+ self.mamba_cache: Optional[MambaCacheManager] = None + + # Initialize logits processing and sampling + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings. + Args: + input_ids: Tensor of input token IDs + Returns: + Embedded representation of the input tokens + """ + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """Forward pass through the model. + + Args: + input_ids: Input token IDs + positions: Position IDs for embeddings + inputs_embeds: Optional pre-computed input embeddings + **kwargs: Additional arguments passed to cache manager + + Returns: + Output hidden states + """ + # Initialize Mamba cache if needed + if self.mamba_cache is None: + num_mamba_layers = self.config.num_hidden_layers + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + + # Get cache parameters for current run + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + # Forward pass through model + hidden_states = self.model( + input_ids, + positions, + mamba_cache_params, + inputs_embeds, + ) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str, + torch.Tensor], + **kwargs) -> Dict[str, torch.Tensor]: + """Copy inputs before CUDA graph capture. + + Args: + input_buffers: Dictionary of input tensors + **kwargs: Additional arguments passed to cache manager + + Returns: + Updated input buffers + """ + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs( + self, batch_size: int) -> Dict[str, torch.Tensor]: + """Get inputs for sequence-length-agnostic graph capture. + + Args: + batch_size: Size of batch to capture + Returns: + Dictionary of capture inputs + """ + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
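+
+        Note: both returned shapes are per-rank shapes; the conv channels and
+        the number of Mamba heads are divided by the tensor-parallel world
+        size below.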
+ + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + world_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.mamba_expand * self.config.hidden_size + + # Extend groups if needed to ensure all groups needed by a head + # are sharded together + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards( + self.config.mamba_ngroups, world_size)) + + # Calculate conv state shape (includes groups) + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # Calculate temporal state shape (per-head states) + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(divide(intermediate_size, self.config.mamba_headdim), + world_size), + self.config.mamba_headdim, + self.config.mamba_d_state, + ) + + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + """Compute logits for next token prediction. + + Args: + hidden_states: Hidden states from model forward pass + sampling_metadata: Metadata for sampling process + + Returns: + Logits for next token prediction + """ + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + """Sample next tokens from computed logits. 
+ + Args: + logits: Computed logits for next token prediction + sampling_metadata: Metadata for sampling process + + Returns: + Sampled tokens and related sampling information + """ + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + weights_dict = {} + for key, loaded_weight in weights: + if "A_log" in key: + key = key.replace("A_log", "A") + elif "adapter_list" in key: + key = key.replace("0.weight", "A.weight") + key = key.replace("1.weight", "B.weight") + weights_dict[key] = loaded_weight + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for chkpt_weight_name, loaded_weight in weights_dict.items(): + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in chkpt_weight_name: + continue + chkpt_weight_name = chkpt_weight_name.replace( + weight_name, param_name) + param = params_dict[chkpt_weight_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if chkpt_weight_name not in params_dict: + continue + param = params_dict[chkpt_weight_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(chkpt_weight_name) + return loaded_params diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py new file mode 100644 index 0000000..2b1294b --- /dev/null +++ b/vllm/model_executor/parameter.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fractions import Fraction +from typing import Callable, Optional, Union + +import torch +from torch.nn import Parameter + +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.utils import _make_synced_weight_loader + +__all__ = [ + "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", + "ModelWeightParameter", "ChannelQuantScaleParameter", + "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" +] + +logger = init_logger(__name__) + + +class BasevLLMParameter(Parameter): + """ + Base parameter for vLLM linear layers. Extends the torch.nn.parameter + by taking in a linear weight loader. Will copy the loaded weight + into the parameter when the provided weight loader is called. + """ + + def __new__(cls, data: torch.Tensor, **kwargs): + + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, weight_loader: Callable): + """ + Initialize the BasevLLMParameter + + :param data: torch tensor with the parameter data + :param weight_loader: weight loader callable + + :returns: a torch.nn.parameter + """ + + # During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. 
+ from vllm.platforms import current_platform + if current_platform.is_tpu(): + weight_loader = _make_synced_weight_loader(weight_loader) + + self._weight_loader = weight_loader + + @property + def weight_loader(self): + return self._weight_loader + + def _is_1d_and_scalar(self, loaded_weight: torch.Tensor): + cond1 = self.data.ndim == 1 and self.data.numel() == 1 + cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1 + return (cond1 and cond2) + + def _assert_and_load(self, loaded_weight: torch.Tensor): + assert (self.data.shape == loaded_weight.shape + or self._is_1d_and_scalar(loaded_weight)) + self.data.copy_(loaded_weight) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + +class _ColumnvLLMParameter(BasevLLMParameter): + """ + Private class defining weight loading functionality + (load_merged_column_weight, load_qkv_weight) + for parameters being loaded into linear layers with column + parallelism. This includes QKV and MLP layers which are + not already fused on disk. Requires an output dimension + to be defined. Called within the weight loader of + each of the column parallel linear layers. + """ + + def __init__(self, output_dim: int, **kwargs): + self._output_dim = output_dim + super().__init__(**kwargs) + + @property + def output_dim(self): + return self._output_dim + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow(self.output_dim, + tp_rank * shard_size, shard_size) + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + if isinstance( + self, + (PackedColumnParameter, + PackedvLLMParameter)) and self.packed_dim == self.output_dim: + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size) + + param_data = self.data + + tp_rank = get_tensor_model_parallel_rank() + param_data = param_data.narrow(self.output_dim, shard_offset, + shard_size) + loaded_weight = loaded_weight.narrow(self.output_dim, + tp_rank * shard_size, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + shard_id = kwargs.get("shard_id") + num_heads = kwargs.get("num_heads") + + if isinstance( + self, + (PackedColumnParameter, + PackedvLLMParameter)) and self.output_dim == self.packed_dim: + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size) + + param_data = self.data + tp_rank = get_tensor_model_parallel_rank() + shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, + shard_size) + loaded_weight = loaded_weight.narrow(self.output_dim, + shard_id * shard_size, shard_size) + + 
assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowvLLMParameter(BasevLLMParameter): + """ + Parameter class defining weight_loading functionality + (load_row_parallel_weight) for parameters being loaded + into linear layers with row parallel functionality. + Requires an input_dim to be defined. + """ + + def __init__(self, input_dim: int, **kwargs): + self._input_dim = input_dim + super().__init__(**kwargs) + + @property + def input_dim(self): + return self._input_dim + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.data.shape[self.input_dim] + loaded_weight = loaded_weight.narrow(self.input_dim, + tp_rank * shard_size, shard_size) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + pass + + +class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + grouped quantization. Uses both column and row parallelism. + """ + pass + + +class ChannelQuantScaleParameter(_ColumnvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + channel-wise quantization. Equivalent to _ColumnvLLMParameter. + """ + pass + + +class PerTensorScaleParameter(BasevLLMParameter): + """ + Parameter class for scales where the number of scales is + equivalent to the number of logical matrices in fused linear + layers (e.g. for QKV, there are 3 scales loaded from disk). + This is relevant to weights with per-tensor quantization. + Adds functionality to map the scalers to a shard during + weight loading. + + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading + """ + + def __init__(self, **kwargs): + self.qkv_idxs = {"q": 0, "k": 1, "v": 2} + super().__init__(**kwargs) + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + assert isinstance(shard_id, str) + assert shard_id in self.qkv_idxs + return self.qkv_idxs[shard_id] + + # For row parallel layers, no sharding needed + # load weight into parameter as is + def load_row_parallel_weight(self, *args, **kwargs): + super().load_row_parallel_weight(*args, **kwargs) + + def load_merged_column_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_qkv_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_column_parallel_weight(self, *args, **kwargs): + super().load_row_parallel_weight(*args, **kwargs) + + def _load_into_shard_id(self, loaded_weight: torch.Tensor, + shard_id: Union[str, int], **kwargs): + """ + Slice the parameter data based on the shard id for + loading. 
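+
+        String shard ids ("q", "k", "v") are mapped to integer shard indices
+        via `_shard_id_as_int` before indexing into the fused parameter.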
+ """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedColumnParameter(_ColumnvLLMParameter): + """ + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. + """ + + def __init__(self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. + """ + + def __init__(self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size) + + +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. 
+ """ + + pass + + +def permute_param_layout_(param: BasevLLMParameter, input_dim: int, + output_dim: int, **kwargs) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters when either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, + marlin_tile_size): + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, + marlin_tile_size): + shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + if marlin_tile_size is not None: + return _adjust_shard_indexes_for_marlin( + shard_size=shard_size, + shard_offset=shard_offset, + marlin_tile_size=marlin_tile_size) + return shard_size, shard_offset diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py new file mode 100644 index 0000000..dea8b0e --- /dev/null +++ b/vllm/model_executor/pooling_metadata.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + + +class PoolingMetadata: + """Metadata for pooling operations in the Pooler layer. + + This class holds the necessary information for pooling operations, + providing context for how to perform pooling and other related operations. + + Attributes: + seq_groups: List of (seq_ids, pooling_params). 
+ seq_data: A mapping of sequence ID to additional sequence data. + prompt_lens: List of the lengths of each prompt. + """ + + def __init__( + self, + seq_groups: List[Tuple[List[int], PoolingParams]], + seq_data: Dict[int, Any], # Specific data related to sequences + prompt_lens: List[int], + ) -> None: + self.seq_groups = seq_groups + self.seq_data = seq_data + self.prompt_lens = prompt_lens + + def __repr__(self) -> str: + return ("PoolingMetadata(" + f"seq_groups={self.seq_groups}, " + f"seq_data={self.seq_data}, " + f"prompt_lens={self.prompt_lens})") + + +@dataclass +class PoolingTensors: + """Tensors for pooling.""" + + prompt_lens: torch.Tensor + + @classmethod + def from_pooling_metadata( + cls, + pooling_metadata: "PoolingMetadata", + device: torch.device, + ) -> "PoolingTensors": + """ + Create PoolingTensors from PoolingMetadata. + + Args: + pooling_metadata: PoolingMetadata instance to convert. + device: Device to store the tensors. + """ + # Convert prompt lengths to tensor + pin_memory = is_pin_memory_available() + + prompt_lens_t = torch.tensor( + pooling_metadata.prompt_lens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + + return cls(prompt_lens=prompt_lens_t.to(device=device, + non_blocking=True), ) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py new file mode 100644 index 0000000..d76c75d --- /dev/null +++ b/vllm/model_executor/sampling_metadata.py @@ -0,0 +1,596 @@ +# SPDX-License-Identifier: Apache-2.0 + +from array import array +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, + SequenceGroupMetadata) +from vllm.utils import (PyObjectCache, async_tensor_h2d, + is_pin_memory_available, make_tensor_with_pad) + +_SAMPLING_EPS = 1e-5 + + +@dataclass +class SequenceGroupToSample: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Sequence ids for the sequence group in a previous step. + seq_ids: List[int] + sampling_params: SamplingParams + # seq_id -> sequence data. + seq_data: Dict[int, SequenceData] + # The length of the sequence (all tokens seen in the past + new token to + # compute attention) of the sequence group. None if it is in a decode + # stage. + seq_len: Optional[int] + # The length of new query tokens to compute in the current step. None if it + # is in a decode stage. The length of query_len <= seq_len if chunked + # prefill is enabled. + query_len: Optional[int] + # A random number generator for sampling. + generator: Optional[torch.Generator] + # True if the sequence group is in prefill stage. False if it is in a + # decode stage. + is_prompt: bool + # Query token indices from logits. to compute prompt logprob. Empty if + # prompt logprob is not required. + prompt_logprob_indices: List[int] + # Sample token indices from logits. Empty if sampling is not required. 
+ sample_indices: List[int] + + @property + def do_sample(self): + return len(self.sample_indices) > 0 + + def __post_init__(self): + if len(self.prompt_logprob_indices) > 0: + assert self.sampling_params.prompt_logprobs is not None + if self.is_prompt: + assert self.seq_len is not None + assert self.query_len is not None + + +def gen_seq_group_to_sample_builder(num_seqs: int): + return lambda: SequenceGroupToSample( + seq_ids=[0] * num_seqs, + sampling_params=None, + seq_data=None, # type: ignore + seq_len=0, + query_len=0, + generator=None, + is_prompt=True, + prompt_logprob_indices=[], + sample_indices=[], + ) + + +class SamplingMetadataCache: + """Used to cache SamplingMetadata objects between scheduler iterations""" + + def __init__(self): + self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} + + def get_cached_seq_group_to_sample(self, num_seqs): + if num_seqs not in self._seq_group_to_sample_cache: + self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( + gen_seq_group_to_sample_builder(num_seqs)) + + obj = self._seq_group_to_sample_cache[num_seqs].get_object() + return obj + + def reset(self): + for cache in self._seq_group_to_sample_cache.values(): + cache.reset() + + +class SamplingMetadata: + """Metadata for input sequences. Used in sampler. + + The usage is as follow; + ``` + hidden_states = execute_model(...) + logits = hidden_states[sampling_metadata.selected_token_indices] + sample(logits) + + def sample(logits): + # Use categorized_sample_indices for sampling.... + ``` + + Args: + seq_groups: List of batched sequence groups. + selected_token_indices: (num_query_tokens_to_logprob). Indices to find + logits from the initial model output hidden states. + categorized_sample_indices: SamplingType -> token indices to sample. + Each token indices is 2D tensor of (num_indices, num_indices) where + the first item means the sample index within the returned logit + (before pruning padding), and the second item means the sample + index after pruning using selected_token_indices. + For example, if the returned logit is [1, 2, 3], and we select + [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, + The first tuple is [1, 2] (sampled index within original logit), + and the second tuple is [0, 1] (sampled index within pruned logit). + num_prompts: Number of prompt sequence groups in seq_groups. + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + serialization of token outputs. + reuse_sampling_tensors: Indicates if we want to reuse sampling + tensors that are part of the sampler forward pass. Currently, + it is mainly used for multi-step decode. 
+ + """ + + def __init__( + self, + seq_groups: List[SequenceGroupToSample], + selected_token_indices: torch.Tensor, + categorized_sample_indices: Dict[SamplingType, torch.Tensor], + num_prompts: int, + skip_sampler_cpu_output: bool = False, + reuse_sampling_tensors: bool = False, + ) -> None: + self.seq_groups = seq_groups + self.selected_token_indices = selected_token_indices + self.categorized_sample_indices = categorized_sample_indices + self.num_prompts = num_prompts + self.skip_sampler_cpu_output = skip_sampler_cpu_output + self.reuse_sampling_tensors = reuse_sampling_tensors + + @staticmethod + def prepare( + seq_group_metadata_list: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: List[int], + device: str, + pin_memory: bool, + generators: Optional[Dict[str, torch.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, + ) -> "SamplingMetadata": + ( + seq_groups, + selected_token_indices, + categorized_sample_indices, + num_prompts, + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, + device, generators, cache) + selected_token_indices = async_tensor_h2d( + selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory, + ) + categorized_sample_indices = { + t: + async_tensor_h2d( + seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory, + ) + for t, seq_ids in categorized_sample_indices.items() + } + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + num_prompts=num_prompts, + ) + return sampling_metadata + + def __repr__(self) -> str: + return ( + "SamplingMetadata(" + f"seq_groups={self.seq_groups}, " + f"selected_token_indices={self.selected_token_indices}, " + f"categorized_sample_indices={self.categorized_sample_indices})") + + +def _prepare_seq_groups( + seq_group_metadata_list: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: List[int], + device: str, + generators: Optional[Dict[str, torch.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, +) -> Tuple[ + List[SequenceGroupToSample], + List[int], + Dict[SamplingType, List[int]], + int, +]: + """Prepare sequence groups and indices for sampling. + + Args: + seq_group_metadata_list: A list of sequence group to batch. + seq_lens: A list of sequence lens per sequence group. + Index of prompt len should match with seq_group_metadata_list. + query_lens: A list of query lengths. Prompt lens include the length + of entire prompt tokens, and it could be shorter. + device: A device to use for random number generators, + `SequenceGroupToSample.generator`. + generators: A store of per-request random number generators used + for seeded requests. + + Returns: + seq_groups: A list of sequence group to sample. + selected_token_indices: See the definition from `SamplingMetadata`. + categorized_sample_indices: See the definition from `SamplingMetadata`. + num_prompts: Total number of prompts from `seq_group_metadata_list`. + """ + # Batched sequence groups for the current model forward stsep. + seq_groups: List[SequenceGroupToSample] = [] + # A list of token indices to sample/compute logprob. It is used to + # prune the outcome logits from the model for the performance. + selected_token_indices: List[int] = [] + # Used for selected_token_indices. 
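+    # Running offset into the model's flattened output; advanced per sequence
+    # group as prompt-logprob and sample positions are assigned below.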
+ model_output_idx = 0 + + # Sampling type -> ( + # indices to sample/prompt logprob within pruned output logits, + # indices to sample within pruned logits) + categorized_sample_indices: Dict[SamplingType, List[int]] = { + t: [] + for t in SamplingType + } + # Index of logits to compute logprob. Logits include both prompt logprob + # and sample logprob indices. + logit_idx = 0 + # Total number of prompts from given sequence groups. + num_prompts = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = seq_group_metadata.seq_data.keys() + + if cache is not None: + sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) + + for j, seq_id in enumerate(seq_ids): + sample_obj.seq_ids[j] = seq_id + + sample_obj.prompt_logprob_indices.clear() + sample_obj.sample_indices.clear() + + sampling_params = seq_group_metadata.sampling_params + is_prompt = seq_group_metadata.is_prompt + generator: Optional[torch.Generator] = None + # If the current seq group is in decode stage, it is None. + seq_len: Optional[int] = None + query_len: Optional[int] = None + prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices + if cache is not None else []) + sample_indices: List[int] = (sample_obj.sample_indices + if cache is not None else []) + do_sample = seq_group_metadata.do_sample + + if seq_group_metadata.is_prompt: + if sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed( + sampling_params.seed) + if generators is not None: + generators[seq_group_metadata.request_id] = generator + + num_prompts += 1 + num_prefill_sample = len(seq_ids) + assert num_prefill_sample == 1 + assert query_lens is not None and seq_lens is not None + query_len, seq_len = query_lens[i], seq_lens[i] + # If we need sampling, exclude num_prefill_sample tokens from + # prompt logprob. + prompt_logprob_len = (query_len - num_prefill_sample + if do_sample else query_len) + sample_len = num_prefill_sample if do_sample else 0 + else: + # Decode + prompt_logprob_len = 0 + query_len = query_lens[i] if query_lens is not None and len( + query_lens) > 0 else 1 + sample_len = len(seq_ids) * query_len if do_sample else 0 + + if sampling_params.seed is not None and generators is not None: + generator = generators.get(seq_group_metadata.request_id) + + # Update indices to select from the model output. + """ + This blocks computes selected_token_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + """ + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + prompt_logprob_len)) + model_output_idx += prompt_logprob_len + if do_sample: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + sample_len)) + model_output_idx += sample_len + + # We now find indices for logprob computation and sampling. + """ + This block computes categorized_sample_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + def sample(logits): + # Use categorized_sample_indices for sampling. + # prompt_logprob_indices to find prompt logprob indices. + # sample_indices to find sample indices. 
+ """ + + if sampling_params.prompt_logprobs is not None: + prompt_logprob_indices.extend( + range(logit_idx, logit_idx + prompt_logprob_len)) + logit_idx += prompt_logprob_len + if do_sample: + sample_indices.extend(range(logit_idx, logit_idx + sample_len)) + categorized_sample_indices[sampling_params.sampling_type].extend( + list(range(logit_idx, logit_idx + sample_len))) + logit_idx += sample_len + + if cache is not None: + sample_obj.sampling_params = sampling_params + sample_obj.seq_data = seq_group_metadata.seq_data + sample_obj.seq_len = seq_len + sample_obj.query_len = query_len + sample_obj.generator = generator + sample_obj.is_prompt = is_prompt + else: + sample_obj = SequenceGroupToSample( + seq_ids=list(seq_ids), + sampling_params=sampling_params, + seq_data=seq_group_metadata.seq_data, + seq_len=seq_len, + query_len=query_len, + generator=generator, + is_prompt=is_prompt, + prompt_logprob_indices=list(prompt_logprob_indices), + sample_indices=list(sample_indices), + ) + + seq_groups.append(sample_obj) + + if cache is not None: + cache.reset() + + return (seq_groups, selected_token_indices, categorized_sample_indices, + num_prompts) + + +@dataclass +class SamplingTensors: + """Tensors for sampling.""" + + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + min_ps: torch.Tensor + presence_penalties: torch.Tensor + frequency_penalties: torch.Tensor + repetition_penalties: torch.Tensor + prompt_tokens: torch.Tensor + output_tokens: torch.Tensor + + @classmethod + def from_sampling_metadata( + cls, + sampling_metadata: "SamplingMetadata", + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tuple["SamplingTensors", bool, bool, bool]: + prompt_tokens: List[array] = [] + output_tokens: List[array] = [] + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + min_ps: List[float] = [] + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + repetition_penalties: List[float] = [] + do_penalties = False + do_top_p_top_k = False + do_min_p = False + + assert sampling_metadata.seq_groups is not None + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + temperature = sampling_params.temperature + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + top_p = sampling_params.top_p + min_p = sampling_params.min_p + + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. 
+ temperature = 1.0 + if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS + or top_k != vocab_size): + do_top_p_top_k = True + if not do_min_p and min_p > _SAMPLING_EPS: + do_min_p = True + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True + + is_prompt = seq_group.is_prompt + if is_prompt and sampling_params.prompt_logprobs is not None: + # For tokens in the prompt that we only need to get + # their logprobs + query_len = seq_group.query_len + assert query_len is not None + prefill_len = len(seq_group.prompt_logprob_indices) + temperatures += [temperature] * prefill_len + top_ps += [top_p] * prefill_len + top_ks += [top_k] * prefill_len + min_ps += [min_p] * prefill_len + presence_penalties += [0] * prefill_len + frequency_penalties += [0] * prefill_len + repetition_penalties += [1] * prefill_len + + if seq_group.do_sample: + sample_lens = len(seq_group.sample_indices) + assert sample_lens >= len(seq_ids) + temperatures += [temperature] * sample_lens + top_ps += [top_p] * sample_lens + top_ks += [top_k] * sample_lens + min_ps += [min_p] * sample_lens + presence_penalties += [p] * sample_lens + frequency_penalties += [f] * sample_lens + repetition_penalties += [r] * sample_lens + + if do_penalties: + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + if (seq_group.is_prompt + and sampling_params.prompt_logprobs is not None): + prefill_len = len(seq_group.prompt_logprob_indices) + prompt_tokens.extend( + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) + output_tokens.extend( + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) + if seq_group.do_sample: + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids_array) + output_tokens.append(seq_data.output_token_ids_array) + + sampling_tensors = SamplingTensors.from_lists( + temperatures, + top_ps, + top_ks, + min_ps, + presence_penalties, + frequency_penalties, + repetition_penalties, + prompt_tokens, + output_tokens, + vocab_size, + device, + dtype, + ) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) + + @classmethod + def from_lists( + cls, + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], + min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[array], + output_tokens: List[array], + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> "SamplingTensors": + # Note that the performance will be very bad without + # pinned memory. 
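+        # (pinned host buffers are what make the non_blocking copies to the
+        # device at the end of this method actually asynchronous)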
+ pin_memory = is_pin_memory_available() + + do_penalties = prompt_tokens or output_tokens + + if do_penalties: + prompt_t = make_tensor_with_pad( + prompt_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) + output_t = make_tensor_with_pad( + output_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) + else: + empty_tensor = torch.empty(0, device=device, dtype=torch.long) + prompt_t = empty_tensor + output_t = empty_tensor + + temperatures_t = torch.tensor( + temperatures, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ps_t = torch.tensor( + top_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + min_ps_t = torch.tensor( + min_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + presence_penalties_t = torch.tensor( + presence_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + frequency_penalties_t = torch.tensor( + frequency_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + repetition_penalties_t = torch.tensor( + repetition_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ks_t = torch.tensor( + top_ks, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + # Because the memory is pinned, we can do non-blocking + # transfer to device. + + return cls( + temperatures=temperatures_t.to(device=device, non_blocking=True), + top_ps=top_ps_t.to(device=device, non_blocking=True), + top_ks=top_ks_t.to(device=device, non_blocking=True), + min_ps=min_ps_t.to(device=device, non_blocking=True), + presence_penalties=presence_penalties_t.to(device=device, + non_blocking=True), + frequency_penalties=frequency_penalties_t.to(device=device, + non_blocking=True), + repetition_penalties=repetition_penalties_t.to(device=device, + non_blocking=True), + prompt_tokens=prompt_t.to(device=device, non_blocking=True), + output_tokens=output_t.to(device=device, non_blocking=True), + ) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py new file mode 100644 index 0000000..04f922d --- /dev/null +++ b/vllm/model_executor/utils.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utils for model executor.""" +from typing import Any, Dict, Optional + +import torch + + +def set_random_seed(seed: int) -> None: + from vllm.platforms import current_platform + current_platform.seed_everything(seed) + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: Optional[Dict[str, Any]], +): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. + """ + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr( + weight, key), (f"Overwriting existing tensor attribute: {key}") + + # NOTE(woosuk): During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. 
+ # TODO(woosuk): Remove this hack once we have a better solution. + from vllm.platforms import current_platform + if current_platform.is_tpu() and key == "weight_loader": + value = _make_synced_weight_loader(value) + setattr(weight, key, value) + + +def _make_synced_weight_loader(original_weight_loader): + + def _synced_weight_loader(param, *args, **kwargs): + original_weight_loader(param, *args, **kwargs) + torch._sync(param) + + return _synced_weight_loader diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py new file mode 100644 index 0000000..741bd1a --- /dev/null +++ b/vllm/multimodal/__init__.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .base import MultiModalPlaceholderMap, MultiModalPlugin +from .hasher import MultiModalHashDict, MultiModalHasher +from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, + MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict, NestedTensors) +from .registry import MultiModalRegistry + +MULTIMODAL_REGISTRY = MultiModalRegistry() +""" +The global :class:`~MultiModalRegistry` is used by model runners to +dispatch data processing according to the target model. + +See also: + :ref:`mm-processing` +""" + +__all__ = [ + "BatchedTensorInputs", + "ModalityData", + "MultiModalDataBuiltins", + "MultiModalDataDict", + "MultiModalHashDict", + "MultiModalHasher", + "MultiModalKwargs", + "MultiModalPlaceholderDict", + "MultiModalPlaceholderMap", + "MultiModalPlugin", + "NestedTensors", + "MULTIMODAL_REGISTRY", + "MultiModalRegistry", +] diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py new file mode 100644 index 0000000..f379ec1 --- /dev/null +++ b/vllm/multimodal/audio.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 + +import base64 +from io import BytesIO +from pathlib import Path + +import numpy as np +import numpy.typing as npt + +from vllm.inputs.registry import InputContext +from vllm.utils import PlaceholderModule + +from .base import MediaIO, MultiModalPlugin +from .inputs import AudioItem, ModalityData, MultiModalKwargs + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +try: + import soundfile +except ImportError: + soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] + + +class AudioPlugin(MultiModalPlugin): + """Plugin for audio data.""" + + def get_data_key(self) -> str: + return "audio" + + def _default_input_mapper( + self, + ctx: InputContext, + data: ModalityData[AudioItem], + **mm_processor_kwargs, + ) -> MultiModalKwargs: + raise NotImplementedError("There is no default audio input mapper") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + raise NotImplementedError( + "There is no default maximum multimodal tokens") + + +def resample_audio( + audio: npt.NDArray[np.floating], + *, + orig_sr: float, + target_sr: float, +) -> npt.NDArray[np.floating]: + return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) + + +class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): + + def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: + return librosa.load(BytesIO(data), sr=None) + + def load_base64( + self, + media_type: str, + data: str, + ) -> tuple[npt.NDArray, float]: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: + return librosa.load(filepath, sr=None) + + def encode_base64(self, media: tuple[npt.NDArray, float]) -> str: + audio, sr = media + + with 
BytesIO() as buffer: + soundfile.write(buffer, audio, sr, format="WAV") + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py new file mode 100644 index 0000000..5159b0b --- /dev/null +++ b/vllm/multimodal/base.py @@ -0,0 +1,468 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Sequence +from pathlib import Path +from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, + Optional, TypeVar, Union) + +from torch import nn + +from vllm.inputs import InputContext +from vllm.logger import init_logger +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs) + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import SequenceGroupMetadata + +from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, + PlaceholderRange) + +logger = init_logger(__name__) + +MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], + MultiModalKwargs] +""" +Return a dictionary to be passed as keyword arguments to +:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers +and processors in HuggingFace Transformers. + +If the data is not supported, throw :exc:`TypeError`. +""" + +MultiModalTokensCalc = Union[int, Callable[[InputContext], int]] +""" +Calculate the maximum number of multimodal tokens input to the language +model. This does not include tokens that correspond to the input text. +""" + +_T = TypeVar("_T") +N = TypeVar("N", bound=type[nn.Module]) + + +class MultiModalPlugin(ABC): + """ + Base class that defines data processing logic for a specific modality. + + In particular, we adopt a registry pattern to dispatch data processing + according to the model being used (considering that different models may + process the same data differently). This registry is in turn used by + :class:`~MultiModalRegistry` which acts at a higher level + (i.e., the modality of the data). + """ + + def __init__(self) -> None: + self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() + self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() + + @abstractmethod + def get_data_key(self) -> str: + """ + Get the data key corresponding to the modality. + """ + raise NotImplementedError + + @abstractmethod + def _default_input_mapper( + self, + ctx: InputContext, + data: ModalityData[Any], + **mm_processor_kwargs, + ) -> MultiModalKwargs: + """ + Return a dictionary to be passed as keyword arguments to + :meth:`~torch.nn.Module.forward`. This is similar in concept to + tokenizers and processors in HuggingFace Transformers. + + If the data is not supported, throw :exc:`TypeError`. + """ + raise NotImplementedError + + def register_input_mapper( + self, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper to a model class. + + When the model receives input data that matches the modality served by + this plugin (see :meth:`get_data_key`), the provided function is + invoked to transform the data into a dictionary of model inputs. + + If `None` is provided, then the default input mapper is used instead. + """ + + def wrapper(model_cls: N) -> N: + if self._input_mappers.contains(model_cls, strict=True): + logger.warning( + "Model class %s already has an input mapper " + "registered to %s. 
It is overwritten by the new one.", + model_cls, + self, + ) + + self._input_mappers[model_cls] = (mapper + or self._default_input_mapper) + + return model_cls + + return wrapper + + def map_input( + self, + model_config: "ModelConfig", + data: ModalityData[Any], + mm_processor_kwargs: Optional[dict[str, Any]], + ) -> MultiModalKwargs: + """ + Transform the data into a dictionary of model inputs using the + input mapper registered for that model. + + The model is identified by ``model_config``. + + Raises: + TypeError: If the data type is not supported. + """ + + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + mapper = self._input_mappers.get(model_cls) + + if mapper is None: + raise KeyError(f"No input mapper in {self} is registered for " + f"model class {model_cls.__name__}.") + + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + # In the case of the default mapper, we have to get resource + # processor through its HuggingFace autoclass; since this goes + # through **kwargs, we can't inspect it the same way, so we allow + # drop mm_processor_kwargs based on signature inspection + # if we're using the default mapper. + # + # This should be safe in general due to the sanitation, since the + # transformers resource should filter unused kwargs anyway. + uses_default_mapper = mapper == self._default_input_mapper + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + mm_processor_kwargs, + callable=mapper, + allow_var_kwargs=uses_default_mapper, + ) + return mapper(InputContext(model_config), data, **mm_processor_kwargs) + + @abstractmethod + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + """ + Calculate the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model. + """ + raise NotImplementedError + + def _validate_max_multimodal_tokens(self, max_mm_tokens: int): + if max_mm_tokens < 1: + raise ValueError("You should set the number of tokens to a " + f"positive integer. Found: {max_mm_tokens}") + + def register_max_multimodal_tokens( + self, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model + for a model class. + + If `None` is provided, then the default calculation is used instead. + """ + + def wrapper(model_cls: N) -> N: + if self._max_mm_tokens.contains(model_cls, strict=True): + logger.warning( + "Model class %s already calculates maximum number of " + "tokens in %s. It is overwritten by the new one.", + model_cls, + self, + ) + + if isinstance(max_mm_tokens, int): + self._validate_max_multimodal_tokens(max_mm_tokens) + + self._max_mm_tokens[model_cls] = ( + max_mm_tokens or self._default_max_multimodal_tokens) + + return model_cls + + return wrapper + + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + If this registry is not applicable to the model, `0` is returned. + + The model is identified by ``model_config``. 
+ """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + from vllm.model_executor.models import supports_multimodal + + model_cls, _ = get_model_architecture(model_config) + + if not supports_multimodal(model_cls): + return 0 + + max_mm_tokens = self._max_mm_tokens.get(model_cls) + if max_mm_tokens is None: + return 0 + + if callable(max_mm_tokens): + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + max_mm_tokens, + overrides=model_config.mm_processor_kwargs, + requires_kw_only=False, + allow_var_kwargs=True, + ) + max_mm_tokens = max_mm_tokens(InputContext(model_config), + **mm_processor_kwargs) + + self._validate_max_multimodal_tokens(max_mm_tokens) + + return max_mm_tokens + + +class MultiModalPlaceholderMap: + """ + Relates multi-modal embeddings to their corresponding placeholders. + """ + + class IndexMap(NamedTuple): + src: list[int] + dest: list[int] + + src_ranges: list[range] + """ + The indices of the multi-modal embeddings that will replace the + corresponding placeholder embeddings pointed to by ``dest_ranges``. + """ + + src_len: int + """ + The total number of flattened multi-modal embeddings. + """ + + dest_ranges: list[range] + """ + The indices of the placeholder embeddings that will be replaced by the + multimodal embeddings. + """ + + dest_len: int + """ + The total number of embeddings in the destination tensor. + """ + + def __init__(self): + self.src_ranges = [] + self.src_len = 0 + self.dest_ranges = [] + self.dest_len = 0 + + @classmethod + def from_seq_group( + cls, seq_group: "SequenceGroupMetadata", positions: range + ) -> tuple[Optional[MultiModalDataDict], dict[str, + "MultiModalPlaceholderMap"]]: + """ + Returns the multi-modal items that intersect with the portion of a + prompt (``seq_group``) represented by ``positions``, as well as a + ``MultiModalPlaceholderMap`` that relates the multi-modal embedding + vectors to their corresponding placeholders. + + Examples: + + .. code-block:: + + Prompt: |AAAA BBBB What's in these images?| + Positions: |.................................| + + images = [A, B] + src_ranges = [(0, 4), (4, 8)] + dest_ranges = [(0, 4), (5, 9)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ..... | + + images = [A, B] + src_ranges = [(2, 4), (4, 6)] + dest_ranges = [(0, 2), (3, 5)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ......... 
| + + images = [B] + src_ranges = [(0, 4)] + dest_ranges = [(0, 4)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | .......................| + + images = [] + src_ranges = [] + dest_ranges = [] + """ + seq_mm_data = seq_group.multi_modal_data + seq_mm_placeholders = seq_group.multi_modal_placeholders + + if not seq_mm_data or not seq_mm_placeholders: + return seq_mm_data, {} + + # For merged processor, we directly use mm_kwargs as mm_data + if isinstance(seq_mm_data, MultiModalKwargs): + placeholder_maps = dict[str, MultiModalPlaceholderMap]() + + for modality, placeholders in seq_mm_placeholders.items(): + placeholder_map = MultiModalPlaceholderMap() + + if positions: + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) + + placeholder_maps[modality] = placeholder_map + + return seq_mm_data, placeholder_maps + + mm_data = {**seq_mm_data} + placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( + MultiModalPlaceholderMap) + + for modality, placeholders in seq_mm_placeholders.items(): + mm_items = mm_data.pop(modality) + if not isinstance(mm_items, list): + mm_items = [mm_items] + + if positions: + intersecting_items = placeholder_maps[modality] \ + .append_items_from_seq_group( + positions, + mm_items, + placeholders, + ) + + if intersecting_items: + mm_data[modality] = intersecting_items + + return mm_data, placeholder_maps + + def append_items_from_seq_group( + self, + positions: range, + multi_modal_items: list[_T], + multi_modal_placeholders: Sequence[PlaceholderRange], + ) -> list[_T]: + """ + Adds the multi-modal items that intersect ```positions`` to this + placeholder map and returns the intersecting items. + """ + intersecting_items = [] + + if len(multi_modal_items) != len(multi_modal_placeholders): + raise ValueError( + "Multi-modal placeholders and items must have the same length." + ) + for placeholder_dict, mm_item in zip(multi_modal_placeholders, + multi_modal_items): + placeholder = range( + placeholder_dict["offset"], + placeholder_dict["offset"] + placeholder_dict["length"], + ) + intersection = range( + max(positions.start, placeholder.start), + min(positions.stop, placeholder.stop), + ) + + if not intersection: + # Skip this multi-modal item. + continue + + token_embedding_range = range( + intersection.start - positions.start, + intersection.stop - positions.start, + ) + + multimodal_embedding_range = range( + intersection.start - placeholder.start + self.src_len, + intersection.stop - placeholder.start + self.src_len, + ) + + intersecting_items.append(mm_item) + self.dest_ranges.append(token_embedding_range) + self.src_ranges.append(multimodal_embedding_range) + self.src_len += len(placeholder) + + self.dest_len += len(positions) + return intersecting_items + + def extend(self, other: "MultiModalPlaceholderMap"): + """ + Adds the placeholders from another ``MultiModalPlaceholderMap`` to this + instance based on the source and destination tensors being + concatenated. + """ + + self.src_ranges.extend( + range(self.src_len + r.start, self.src_len + r.stop) + for r in other.src_ranges) + self.src_len += other.src_len + self.dest_ranges.extend( + range(self.dest_len + r.start, self.dest_len + r.stop) + for r in other.dest_ranges) + self.dest_len += other.dest_len + + def index_map(self) -> "IndexMap": + """ + Finalizes the placeholder map into lists of indices that can be used to + index the source and destination tensors. 
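+
+        Illustrative example (values assumed for exposition): with
+        ``src_ranges = [range(0, 2)]`` and ``dest_ranges = [range(3, 5)]``,
+        this returns ``IndexMap(src=[0, 1], dest=[3, 4])``.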
+ """ + + src_indices = [i for r in self.src_ranges for i in r] + dest_indices = [i for r in self.dest_ranges for i in r] + + if len(src_indices) != len(dest_indices): + raise ValueError( + f"The number of source ({len(src_indices)}) and destination " + f"indices ({len(dest_indices)}) must be the same.") + + return MultiModalPlaceholderMap.IndexMap(src=src_indices, + dest=dest_indices) + + +class MediaIO(ABC, Generic[_T]): + + @abstractmethod + def load_bytes(self, data: bytes) -> _T: + raise NotImplementedError + + @abstractmethod + def load_base64(self, media_type: str, data: str) -> _T: + """ + List of media types: + https://www.iana.org/assignments/media-types/media-types.xhtml + """ + raise NotImplementedError + + @abstractmethod + def load_file(self, filepath: Path) -> _T: + raise NotImplementedError diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py new file mode 100644 index 0000000..11665ef --- /dev/null +++ b/vllm/multimodal/hasher.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pickle +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, Optional + +import numpy as np +import torch +from blake3 import blake3 +from PIL import Image + +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.inputs import TokensPrompt + +logger = init_logger(__name__) + +MultiModalHashDict = Mapping[str, list[str]] +""" +A dictionary containing hashes for items in each modality. +""" + + +class MultiModalHasher: + + @classmethod + def serialize_item(cls, obj: object) -> bytes: + # Simple cases + if isinstance(obj, str): + return obj.encode("utf-8") + if isinstance(obj, bytes): + return obj + if isinstance(obj, Image.Image): + return obj.tobytes() + + # Convertible to NumPy arrays + if isinstance(obj, torch.Tensor): + obj = obj.numpy() + if isinstance(obj, (int, float)): + obj = np.array(obj) + if isinstance(obj, np.ndarray): + return obj.tobytes() + + logger.warning( + "No serialization method found for %s. " + "Falling back to pickle.", type(obj)) + + return pickle.dumps(obj) + + @classmethod + def item_to_bytes( + cls, + key: str, + obj: object, + ) -> Iterable[tuple[bytes, bytes]]: + # Recursive cases + if isinstance(obj, (list, tuple)): + for i, elem in enumerate(obj): + yield from cls.item_to_bytes(f"{key}.{i}", elem) + elif isinstance(obj, dict): + for k, v in obj.items(): + yield from cls.item_to_bytes(f"{key}.{k}", v) + else: + key_bytes = cls.serialize_item(key) + value_bytes = cls.serialize_item(obj) + yield key_bytes, value_bytes + + @classmethod + def hash_kwargs(cls, **kwargs: object) -> str: + hasher = blake3() + + for k, v in kwargs.items(): + for k_bytes, v_bytes in cls.item_to_bytes(k, v): + hasher.update(k_bytes) + hasher.update(v_bytes) + + return hasher.hexdigest() + + @classmethod + def hash_prompt_mm_data( + cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]: + """Hash multimodal data in the user input prompt if they exist.""" + + if "multi_modal_data" not in prompt: + return None + + mm_data = prompt["multi_modal_data"] + if not mm_data: + # mm_data can be None or an empty dict. 
+ return None + + mm_items = { + modality: items if isinstance(items, list) else [items] + for modality, items in mm_data.items() + } + + mm_hashes = { + modality: [cls.hash_kwargs(**{modality: item}) for item in items] + for modality, items in mm_items.items() + } + + return mm_hashes diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py new file mode 100644 index 0000000..0c5a84c --- /dev/null +++ b/vllm/multimodal/image.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 + +import base64 +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +import torch +from PIL import Image + +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_get_image_processor +from vllm.utils import is_list_of + +from .base import MediaIO, MultiModalPlugin +from .inputs import ImageItem, ModalityData, MultiModalKwargs + +if TYPE_CHECKING: + from vllm.config import ModelConfig + +logger = init_logger(__name__) + + +class ImagePlugin(MultiModalPlugin): + """Plugin for image data.""" + + def get_data_key(self) -> str: + return "image" + + def _get_hf_image_processor( + self, + model_config: "ModelConfig", + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return cached_get_image_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) + + def _default_input_mapper( + self, + ctx: InputContext, + data: ModalityData[ImageItem], + **mm_processor_kwargs, + ) -> MultiModalKwargs: + model_config = ctx.model_config + + # PIL image + if isinstance(data, Image.Image) or is_list_of(data, Image.Image): + image_processor = self._get_hf_image_processor( + model_config, + mm_processor_kwargs, + ) + + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + try: + # NOTE: It may make sense to forward the mm_processor_kwargs + # here too. For now, to keep it simple, we only allow it be + # used for the initialization call though, just in case the + # signatures of the preprocessor initializer don't match + # preprocess() + batch_data = image_processor \ + .preprocess(data, return_tensors="pt") \ + .data + except Exception: + logger.error( + "Failed to process image (%s) with the default mapper. 
" + "This is most likely an edge-case with this model's image " + "processor in transformers (type: %s), and not vLLM.", + data, + type(image_processor).__name__) + raise + + return MultiModalKwargs(batch_data) + + # Image embedding + elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): + return MultiModalKwargs({"image_embeds": data}) + + raise TypeError(f"Invalid image type: {type(data)}") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + return 3000 + + +def rescale_image_size(image: Image.Image, + size_factor: float, + transpose: int = -1) -> Image.Image: + """Rescale the dimensions of an image by a constant factor.""" + new_width = int(image.width * size_factor) + new_height = int(image.height * size_factor) + image = image.resize((new_width, new_height)) + if transpose >= 0: + image = image.transpose(Image.Transpose(transpose)) + return image + + +class ImageMediaIO(MediaIO[Image.Image]): + + def __init__(self, *, image_mode: str = "RGB") -> None: + super().__init__() + + self.image_mode = image_mode + + def load_bytes(self, data: bytes) -> Image.Image: + image = Image.open(BytesIO(data)) + image.load() + return image.convert(self.image_mode) + + def load_base64(self, media_type: str, data: str) -> Image.Image: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> Image.Image: + image = Image.open(filepath) + image.load() + return image.convert(self.image_mode) + + def encode_base64( + self, + media: Image.Image, + *, + image_format: str = "JPEG", + ) -> str: + image = media + + with BytesIO() as buffer: + image = image.convert(self.image_mode) + image.save(buffer, image_format) + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') + + +class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): + + def __init__(self) -> None: + super().__init__() + + def load_bytes(self, data: bytes) -> torch.Tensor: + buffer = BytesIO(data) + return torch.load(buffer, weights_only=True) + + def load_base64(self, media_type: str, data: str) -> torch.Tensor: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> torch.Tensor: + return torch.load(filepath, weights_only=True) + + def encode_base64(self, media: torch.Tensor) -> str: + return base64.b64encode(media.numpy()).decode('utf-8') diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py new file mode 100644 index 0000000..81d72ff --- /dev/null +++ b/vllm/multimodal/inputs.py @@ -0,0 +1,769 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections import UserDict, defaultdict +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from functools import partial +from itertools import accumulate +from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, + Union, cast, final) + +import numpy as np +import torch +import torch.types +from PIL.Image import Image +from transformers import BatchFeature +from typing_extensions import NotRequired, TypeAlias + +from vllm.jsontree import JSONTree, json_map_leaves +from vllm.utils import full_groupby, is_list_of + +if TYPE_CHECKING: + from .hasher import MultiModalHashDict + +_T = TypeVar("_T") + +HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] +""" +A :class:`transformers.image_utils.ImageInput` representing a single image +item, which can be passed to a HuggingFace :code:`ImageProcessor`. 
+""" + +HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor, + list[np.ndarray], list[torch.Tensor]] +""" +A :class:`transformers.image_utils.VideoInput` representing a single video +item, which can be passed to a HuggingFace :code:`VideoProcessor`. +""" + +HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor] +""" +Represents a single audio +item, which can be passed to a HuggingFace :code:`AudioProcessor`. +""" + +ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor] +""" +A :class:`transformers.image_utils.ImageInput` representing a single image +item, which can be passed to a HuggingFace :code:`ImageProcessor`. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as image embeddings; +these are directly passed to the model without HF processing. +""" + +VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor] +""" +A :class:`transformers.image_utils.VideoInput` representing a single video +item, which can be passed to a HuggingFace :code:`VideoProcessor`. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as video embeddings; +these are directly passed to the model without HF processing. +""" + +AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], + torch.Tensor] +""" +Represents a single audio +item, which can be passed to a HuggingFace :code:`AudioProcessor`. + +Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate +is different from that expected by the model; +these are resampled to the model's sampling rate before being processed by HF. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as audio embeddings; +these are directly passed to the model without HF processing. +""" + +ModalityData: TypeAlias = Union[_T, list[_T]] +""" +Either a single data item, or a list of data items. + +The number of data items allowed per modality is restricted by +:code:`--limit-mm-per-prompt`. +""" + + +@final +class MultiModalDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: ModalityData[ImageItem] + """The input image(s).""" + + video: ModalityData[VideoItem] + """The input video(s).""" + + audio: ModalityData[AudioItem] + """The input audio(s).""" + + +MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +The built-in modalities are defined by :class:`MultiModalDataBuiltins`. +""" + + +class PlaceholderRange(TypedDict): + """ + Placeholder location information for multi-modal data. + + Example: + + Prompt: :code:`AAAA BBBB What is in these images?` + + Images A and B will have: + + .. code-block:: + + A: { "offset": 0, "length": 4 } + B: { "offset": 5, "length": 4 } + """ + + offset: int + """The start index of the placeholder in the prompt.""" + + length: int + """The length of the placeholder.""" + + +NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, + tuple[torch.Tensor, ...]] +""" +Uses a list instead of a tensor if the dimensions of each element do not match. 
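+
+For example (an illustrative value, not taken from this module):
+``[torch.zeros(3, 8), torch.zeros(5, 8)]`` is a valid ``NestedTensors``,
+since the two tensors have mismatched first dimensions and cannot be
+stacked into a single batched tensor.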
+""" + + +def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: + """Equality check between :data:`NestedTensors` objects.""" + if isinstance(a, torch.Tensor): + return isinstance(b, torch.Tensor) and torch.equal(a, b) + elif isinstance(b, torch.Tensor): + return isinstance(a, torch.Tensor) and torch.equal(b, a) + + if isinstance(a, list): + return (isinstance(b, list) + and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))) + if isinstance(b, list): + return (isinstance(a, list) + and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))) + + # Both a and b are scalars + return a == b + + +BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors] +""" +A dictionary containing nested tensors which have been batched via +:meth:`MultiModalKwargs.batch`. +""" + + +@dataclass(frozen=True) +class MultiModalFieldElem: + """ + Represents a keyword argument corresponding to a multi-modal item + in :class:`MultiModalKwargs`. + """ + + modality: str + """ + The modality of the corresponding multi-modal item. + Each multi-modal item can consist of multiple keyword arguments. + """ + + key: str + """ + The key of this field in :class:`MultiModalKwargs`, + i.e. the name of the keyword argument to be passed to the model. + """ + + data: NestedTensors + """ + The tensor data of this field in :class:`MultiModalKwargs`, + i.e. the value of the keyword argument to be passed to the model. + """ + + field: "BaseMultiModalField" + """ + Defines how to combine the tensor data of this field with others + in order to batch multi-modal items together for model inference. + """ + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + + return ((self.modality, self.key) == (other.modality, other.key) + and nested_tensors_equal(self.data, other.data) + and type(self.field) == type(other.field)) # noqa: E721 + + +@dataclass(frozen=True) +class BaseMultiModalField(ABC): + """ + Defines how to interpret tensor data belonging to a keyword argument in + :class:`MultiModalKwargs` for multiple multi-modal items, and vice versa. + """ + + def _field_factory(self, *, modality: str, key: str): + f = partial( + MultiModalFieldElem, + modality=modality, + key=key, + field=self, + ) + + # Allow passing data as positional argument + def factory(data: NestedTensors) -> MultiModalFieldElem: + return f(data=data) + + return factory + + @abstractmethod + def build_elems( + self, + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + """ + Construct :class:`MultiModalFieldElem` instances to represent + the provided data. + + This is the inverse of :meth:`reduce_data`. + """ + raise NotImplementedError + + @abstractmethod + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + raise NotImplementedError + + def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: + """ + Merge the data from multiple instances of :class:`MultiModalFieldElem`. + + This is the inverse of :meth:`build_elems`. 
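+
+        Illustrative example (shapes assumed for exposition): for
+        :class:`MultiModalBatchedField`, two elements whose data are tensors
+        of shape ``(4,)`` reduce to a single stacked tensor of shape
+        ``(2, 4)``.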
+ """ + field_types = [type(item.field) for item in elems] + if len(set(field_types)) > 1: + raise ValueError(f"Cannot merge different {field_types=}") + + return self._reduce_data([item.data for item in elems]) + + +@dataclass(frozen=True) +class MultiModalBatchedField(BaseMultiModalField): + """ + See also: + :func:`MultiModalFieldConfig.batched` + """ + + def build_elems( + self, + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + field_factory = self._field_factory(modality=modality, key=key) + return [field_factory(item) for item in data] + + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + if len(batch) == 1: + # An optimization when `batch` contains only one tensor: + # - produce exactly same result as `torch.stack(batch)` + # - will achieve zero-copy if the tensor is contiguous + return batch[0].unsqueeze(0).contiguous() + first_shape = batch[0].shape + if all(elem.shape == first_shape for elem in batch): + return torch.stack(batch) + + return batch + + +@dataclass(frozen=True) +class MultiModalFlatField(BaseMultiModalField): + """ + See also: + :func:`MultiModalFieldConfig.flat` + :func:`MultiModalFieldConfig.flat_from_sizes` + """ + slices: Sequence[slice] + + def build_elems( + self, + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + field_factory = self._field_factory(modality=modality, key=key) + return [field_factory(data[s]) for s in self.slices] + + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + if len(batch) == 1: + # An optimization when `batch` contains only one tensor: + # - produce exactly same result as `torch.concat(batch)` + # - will achieve zero-copy if the tensor is contiguous + return batch[0].contiguous() + first_shape = batch[0].shape + if all(elem.shape[1:] == first_shape[1:] for elem in batch): + return torch.concat(batch) + + return [e for elem in batch for e in elem] + + +@dataclass(frozen=True) +class MultiModalSharedField(BaseMultiModalField): + """ + See also: + :func:`MultiModalFieldConfig.shared` + """ + batch_size: int + + def build_elems( + self, + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + field_factory = self._field_factory(modality=modality, key=key) + return [field_factory(data)] * self.batch_size + + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + return batch[0] + + +class MultiModalFieldConfig: + + @staticmethod + def batched(modality: str): + """ + Defines a field where an element in the batch is obtained by + indexing into the first dimension of the underlying data. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + + Example: + + .. code-block:: + + Input: + Data: [[AAAA] + [BBBB] + [CCCC]] + + Output: + Element 1: [AAAA] + Element 2: [BBBB] + Element 3: [CCCC] + """ + return MultiModalFieldConfig( + field=MultiModalBatchedField(), + modality=modality, + ) + + @staticmethod + def flat(modality: str, slices: Sequence[slice]): + """ + Defines a field where an element in the batch is obtained by + slicing along the first dimension of the underlying data. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + slices: For each multi-modal item, a slice that is used to extract + the data corresponding to it. + + Example: + + .. 
code-block:: + + Given: + slices: [slice(0, 3), slice(3, 7), slice(7, 9)] + + Input: + Data: [AAABBBBCC] + + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] + """ + return MultiModalFieldConfig( + field=MultiModalFlatField(slices=slices), + modality=modality, + ) + + @staticmethod + def flat_from_sizes(modality: str, size_per_item: torch.Tensor): + """ + Defines a field where an element in the batch is obtained by + slicing along the first dimension of the underlying data. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + slices: For each multi-modal item, the size of the slice that + is used to extract the data corresponding to it. + + Example: + + .. code-block:: + + Given: + size_per_item: [3, 4, 2] + + Input: + Data: [AAABBBBCC] + + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] + + See also: + :func:`MultiModalFieldConfig.flat` + """ + + if size_per_item.ndim != 1: + raise ValueError("size_per_item should be a 1-D tensor, " + f"but found shape: {size_per_item.shape}") + + slice_idxs = [0, *accumulate(size_per_item)] + slices = [ + slice(slice_idxs[i], slice_idxs[i + 1]) + for i in range(len(size_per_item)) + ] + + return MultiModalFieldConfig.flat(modality, slices) + + @staticmethod + def shared(modality: str, batch_size: int): + """ + Defines a field where an element in the batch is obtained by + taking the entirety of the underlying data. + + This means that the data is the same for each element in the batch. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + batch_size: The number of multi-modal items which share this data. + + Example: + + .. code-block:: + + Given: + batch_size: 4 + + Input: + Data: [XYZ] + + Output: + Element 1: [XYZ] + Element 2: [XYZ] + Element 3: [XYZ] + Element 4: [XYZ] + """ + return MultiModalFieldConfig( + field=MultiModalSharedField(batch_size), + modality=modality, + ) + + def __init__(self, field: BaseMultiModalField, modality: str) -> None: + super().__init__() + + self.field = field + self.modality = modality + + def build_elems( + self, + key: str, + batch: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + return self.field.build_elems(self.modality, key, batch) + + +class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): + """ + A collection of :class:`MultiModalFieldElem` + corresponding to a data item in :class:`MultiModalDataItems`. + """ + + @staticmethod + def from_elems(elems: Sequence[MultiModalFieldElem]): + return MultiModalKwargsItem({elem.key: elem for elem in elems}) + + @property + def modality(self) -> str: + modalities = {elem.modality for elem in self.data.values()} + assert len(modalities) == 1, f"Found different modalities={modalities}" + return next(iter(modalities)) + + +# NOTE: UserDict is for V0 compatibility. +# V1 should access individual items via `get_item`. +class MultiModalKwargs(UserDict[str, NestedTensors]): + """ + A dictionary that represents the keyword arguments to + :meth:`~torch.nn.Module.forward`. + + The metadata :code:`items` enables us to obtain the keyword arguments + corresponding to each data item in :class:`MultiModalDataItems`, via + :meth:`get_item` and :meth:`get_items`. 
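+
+    Example (an illustrative sketch; the construction of ``items`` is
+    assumed):
+
+    .. code-block:: python
+
+        mm_kwargs = MultiModalKwargs.from_items(items)
+        num_images = mm_kwargs.get_item_count("image")
+        first_image_item = mm_kwargs.get_item("image", 0)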
+ """ + + @staticmethod + def from_hf_inputs( + hf_inputs: BatchFeature, + config_by_key: Mapping[str, MultiModalFieldConfig], + ): + # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` + # We assume that those fields are not used in vLLM + elems_by_key = dict[str, Sequence[MultiModalFieldElem]]() + keys_by_modality = defaultdict[str, set[str]](set) + for key, config in config_by_key.items(): + batch = hf_inputs.get(key) + if batch is not None: + elems = config.build_elems(key, batch) + if len(elems) > 0: + elems_by_key[key] = elems + keys_by_modality[config.modality].add(key) + + items = list[MultiModalKwargsItem]() + for modality, keys in keys_by_modality.items(): + elems_in_modality = {k: elems_by_key[k] for k in keys} + batch_sizes = {k: len(v) for k, v in elems_in_modality.items()} + + if len(set(batch_sizes.values())) > 1: + raise ValueError( + f"Cannot merge different batch sizes for {modality=}! " + f"Found: {batch_sizes=}") + + batch_size = next(iter(batch_sizes.values())) + for item_idx in range(batch_size): + elems = [v[item_idx] for v in elems_in_modality.values()] + items.append(MultiModalKwargsItem.from_elems(elems)) + + return MultiModalKwargs.from_items(items) + + @staticmethod + def from_items(items: Sequence[MultiModalKwargsItem]): + """Construct a new :class:`MultiModalKwargs` from multiple items.""" + elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) + for item in items: + for key, elem in item.items(): + elems_by_key[key].append(elem) + + data = { + key: elems[0].field.reduce_data(elems) + for key, elems in elems_by_key.items() if len(elems) > 0 + } + + return MultiModalKwargs(data, items=items) + + def __init__( + self, + data: Mapping[str, NestedTensors], + *, + items: Optional[Sequence[MultiModalKwargsItem]] = None, + ) -> None: + super().__init__(data) + + items_by_modality = full_groupby(items or [], key=lambda x: x.modality) + self._items_by_modality = dict(items_by_modality) + + @property + def modalities(self): + return self._items_by_modality.keys() + + @staticmethod + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: + """ + Stack the inner dimensions that have the same shape in + a nested list of tensors. + + Thus, a dimension represented by a list means that the inner + dimensions are different for each element along that dimension. + """ + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors + + # TODO: Remove these once all models have been migrated + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) + + stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. + return stacked + + tensors_ = cast(list[torch.Tensor], stacked) + if len(tensors_) == 1: + # An optimization when `tensors_` contains only one tensor: + # - produce exactly same result as `torch.stack(tensors_)` + # - will achieve zero-copy if the tensor is contiguous + return tensors_[0].unsqueeze(0).contiguous() + + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ + + return torch.stack(tensors_) + + @staticmethod + def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs: + """ + Batch multiple inputs together into a dictionary. + + The resulting dictionary has the same keys as the inputs. 
+ If the corresponding value from each input is a tensor and they all + share the same shape, the output value is a single batched tensor; + otherwise, the output value is a list containing the original value + from each input. + """ + if len(inputs_list) == 0: + return {} + + # We need to consider the case where each item in the batch + # contains different modalities (i.e. different keys). + item_lists = defaultdict[str, list[NestedTensors]](list) + + for inputs in inputs_list: + for k, v in inputs.items(): + item_lists[k].append(v) + + return { + k: MultiModalKwargs._try_stack(item_list) + for k, item_list in item_lists.items() + } + + @staticmethod + def as_kwargs( + batched_inputs: BatchedTensorInputs, + *, + device: torch.types.Device, + ) -> BatchedTensorInputs: + json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) + + json_mapped = json_map_leaves( + lambda x: x.to(device, non_blocking=True), + json_inputs, + ) + + return cast(BatchedTensorInputs, json_mapped) + + def __delitem__(self, key: str) -> None: + super().__delitem__(key) + + for items in self._items_by_modality.values(): + for item in items: + item.pop(key, None) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + if self._items_by_modality != other._items_by_modality: + return False + + ks = self.keys() + return (ks == other.keys() + and all(nested_tensors_equal(self[k], other[k]) for k in ks)) + + def _validate_modality(self, method_name: str, modality: str) -> None: + if not self._items_by_modality: + raise RuntimeError( + f"`{method_name}` is not supported when " + "MultiModalKwargs is not initialized with `items`") + + if modality not in self._items_by_modality: + available_modalities = set(self._items_by_modality.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + + def get_item_count(self, modality: str) -> int: + """Get the number of items belonging to a modality.""" + self._validate_modality("get_item_count", modality) + return len(self._items_by_modality[modality]) + + def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem: + """ + Get the keyword arguments corresponding to an item identified by + its modality and index. + """ + self._validate_modality("get_item", modality) + return self._items_by_modality[modality][item_index] + + def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: + """ + Get the keyword arguments corresponding to each item belonging to + a modality. + """ + self._validate_modality("get_items", modality) + return self._items_by_modality[modality] + + +MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] +""" +A dictionary containing placeholder ranges for each modality. +""" + + +class MultiModalInputs(TypedDict): + """ + Represents the outputs of + :class:`vllm.multimodal.processing.BaseMultiModalProcessor`, + ready to be passed to vLLM internals. 
+ """ + + type: Literal["multimodal"] + """The type of inputs.""" + + prompt: str + """The processed prompt text.""" + + prompt_token_ids: list[int] + """The processed token IDs which includes placeholder tokens.""" + + token_type_ids: NotRequired[list[int]] + """The token type IDs of the prompt.""" + + mm_kwargs: MultiModalKwargs + """Keyword arguments to be directly passed to the model after batching.""" + + mm_hashes: Optional["MultiModalHashDict"] + """The hashes of the multi-modal data.""" + + mm_placeholders: MultiModalPlaceholderDict + """ + For each modality, information about the placeholder tokens in + :code:`prompt_token_ids`. + """ + + +class MultiModalEncDecInputs(MultiModalInputs): + """ + Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor` + ready to be passed to vLLM internals. + """ + + encoder_prompt: str + """The processed encoder prompt text.""" + + encoder_prompt_token_ids: list[int] + """The processed token IDs of the encoder prompt.""" + + encoder_token_type_ids: NotRequired[list[int]] + """The token type IDs of the encoder prompt.""" diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py new file mode 100644 index 0000000..fc5a294 --- /dev/null +++ b/vllm/multimodal/parse.py @@ -0,0 +1,453 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections import UserDict +from collections.abc import Callable, Iterator, Mapping, Sequence +from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, + Union) + +import numpy as np +import torch +from PIL.Image import Image +from transformers import BatchFeature +from typing_extensions import TypeAlias, TypeGuard, assert_never + +from vllm.utils import is_list_of + +from .audio import resample_audio +from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, + ImageItem, ModalityData, MultiModalDataDict, + MultiModalFieldConfig, MultiModalKwargs, VideoItem) + +_T = TypeVar("_T") +_I = TypeVar("_I") + + +class ModalityDataItems(ABC, Generic[_T, _I]): + """ + Represents data items for a modality in :class:`MultiModalDataItems`. + """ + + def __init__(self, data: _T, modality: str) -> None: + super().__init__() + + self.data = data + self.modality = modality + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r}, " + f"len={len(self)})") + + def __len__(self) -> int: + return self.get_count() + + def __getitem__(self, index: int) -> _I: + return self.get(index) + + if TYPE_CHECKING: + # Auto-generated + def __iter__(self) -> Iterator[_I]: + ... 
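+
+    # Illustrative usage (assumed objects, mirroring the dunder methods
+    # above): a concrete subclass such as ImageProcessorItems supports
+    #   items = ImageProcessorItems([img_a, img_b])
+    #   len(items)   # -> 2, via get_count()
+    #   items[0]     # -> img_a, via get(0)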
+ + @abstractmethod + def get_count(self) -> int: + """Get the number of data items.""" + raise NotImplementedError + + @abstractmethod + def get(self, index: int) -> _I: + """Get a data item by its index.""" + raise NotImplementedError + + def get_all(self) -> list[_I]: + """Get all data items.""" + return [self.get(idx) for idx in range(self.get_count())] + + @abstractmethod + def get_processor_data(self) -> Mapping[str, object]: + """Get the data to pass to the HF processor.""" + raise NotImplementedError + + @abstractmethod + def get_passthrough_data(self) -> Mapping[str, object]: + """Get the data to pass directly to the model.""" + raise NotImplementedError + + +class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): + """Base class for data items that are arranged in a list.""" + + def get_count(self) -> int: + return len(self.data) + + def get(self, index: int) -> _T: + return self.data[index] + + def get_processor_data(self) -> Mapping[str, object]: + return {f"{self.modality}s": self.data} + + def get_passthrough_data(self) -> Mapping[str, object]: + return {} + + +class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor]): + """ + Base class for data items that are expressed as a batched embedding tensor, + or a list of embedding tensors (one per item). + """ + + def get_count(self) -> int: + return len(self.data) + + def get(self, index: int) -> torch.Tensor: + return self.data[index] + + def get_processor_data(self) -> Mapping[str, object]: + return {} + + def get_passthrough_data(self) -> Mapping[str, object]: + return {f"{self.modality}_embeds": self.data} + + def get_feature_size(self, item_idx: int) -> int: + return len(self.get(item_idx)) + + +class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor]]): + """ + Base class for data items that are expressed as a dictionary of tensors. + + Usually, the dictionary keys correspond to the outputs of HF processor. 
+ """ + + def __init__( + self, + data: Mapping[str, torch.Tensor], + modality: str, + required_fields: set[str], + fields_factory: Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], + ], + ) -> None: + super().__init__(data, modality) + + missing_required_data_keys = required_fields - data.keys() + if missing_required_data_keys: + data_keys = set(data.keys()) + msg = (f"The data should contain the fields: {required_fields}, " + f"but only found the following keys: {data_keys}") + raise ValueError(msg) + + fields_config = fields_factory(data) + missing_required_fields = required_fields - fields_config.keys() + if missing_required_fields: + fields = set(fields_config.keys()) + msg = f"{required_fields=} should be a subset of {fields=}" + raise ValueError(msg) + + self.fields_config = fields_config + self.required_fields = required_fields + + self._kwargs = MultiModalKwargs.from_hf_inputs( + BatchFeature(dict(data)), + fields_config, + ) + + def get_count(self) -> int: + return self._kwargs.get_item_count(self.modality) + + def get(self, index: int) -> Mapping[str, torch.Tensor]: + return { + k: v.data + for k, v in self._kwargs.get_item(self.modality, index).items() + } + + def get_processor_data(self) -> Mapping[str, object]: + return {} + + def get_passthrough_data(self) -> Mapping[str, object]: + return self.data + + +class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): + + def __init__(self, data: Sequence[HfAudioItem]) -> None: + super().__init__(data, "audio") + + def get_audio_length(self, item_idx: int) -> int: + audio = self.get(item_idx) + return len(audio) + + +class AudioEmbeddingItems(EmbeddingItems): + + def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: + super().__init__(data, "audio") + + +class ImageSize(NamedTuple): + width: int + height: int + + +class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): + + def __init__(self, data: Sequence[HfImageItem]) -> None: + super().__init__(data, "image") + + def get_image_size(self, item_idx: int) -> ImageSize: + image = self.get(item_idx) + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + + +class ImageEmbeddingItems(EmbeddingItems): + + def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: + super().__init__(data, "image") + + +class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): + + def __init__(self, data: Sequence[HfVideoItem]) -> None: + super().__init__(data, "video") + + def get_num_frames(self, item_idx: int) -> int: + return len(self.get(item_idx)) + + def get_frame_size(self, item_idx: int) -> ImageSize: + image = self.get(item_idx)[0] # Assume that the video isn't empty + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + + +class VideoEmbeddingItems(EmbeddingItems): + + def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: + super().__init__(data, "video") + + +_D = TypeVar("_D", bound=ModalityDataItems[Any, Any]) + + +class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): + """ + As :data:`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized + such that each entry corresponds to a list. 
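+
+    For example (illustrative), parsing ``{"image": single_image}`` with
+    :class:`MultiModalDataParser` produces an entry whose
+    :meth:`~ModalityDataItems.get_count` is ``1``, the same as parsing
+    ``{"image": [single_image]}``.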
+ """ + + def get_count(self, modality: str, *, strict: bool = True) -> int: + """ + Get the number of data items belonging to a modality. + + If `strict=False`, return `0` instead of raising :exc:`KeyError` + even if the modality is not found. + """ + if modality not in self: + if strict: + available_modalities = set(self.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + + return 0 + + return self[modality].get_count() + + def get_all_counts(self) -> Mapping[str, int]: + """Get the number of items belonging to each modality.""" + return {m: items.get_count() for m, items in self.items()} + + def get_items( + self, + modality: str, + typ: Union[type[_D], tuple[type[_D], ...]], + ) -> _D: + """ + Get the data items belonging to a modality, + requiring that they belong to a certain type. + """ + if modality not in self: + available_modalities = set(self.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + + items = self[modality] + if not isinstance(items, typ): + raise TypeError(f"Invalid type of data items for {modality=}. " + f"Expected type: {typ}, but " + f"found type: {type(items)}") + + return items # type: ignore[return-value] + + +ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], + Optional[ModalityDataItems[Any, Any]]] + + +class MultiModalDataParser: + """ + Parses :data:`~vllm.multimodal.inputs.MultiModalDataDict` into + :class:`MultiModalDataItems`. + + Args: + target_sr (float, optional): Enables automatic resampling of audio + items to the model's expected sampling rate. + """ + + def __init__(self, *, target_sr: Optional[float] = None) -> None: + super().__init__() + + self.target_sr = target_sr + + def _is_embeddings( + self, data: object + ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]: + if isinstance(data, torch.Tensor): + return data.ndim == 3 + if is_list_of(data, torch.Tensor): + return data[0].ndim == 2 + + return False + + def _is_empty(self, data: object) -> TypeGuard[None]: + if isinstance(data, list): + return len(data) == 0 + if isinstance(data, (np.ndarray, torch.Tensor)): + return data.size == 0 + + return False + + def _get_audio_with_sr( + self, + audio: AudioItem, + ) -> tuple[np.ndarray, Optional[float]]: + if isinstance(audio, tuple): + return audio + if isinstance(audio, list): + return np.array(audio), None + if isinstance(audio, np.ndarray): + return audio, None + if isinstance(audio, torch.Tensor): + return audio.numpy(), None + + assert_never(audio) + + def _parse_audio_data( + self, + data: ModalityData[AudioItem], + ) -> Optional[ModalityDataItems[Any, Any]]: + # also check single audio item with sampling rate + if self._is_empty(data) or (isinstance(data, tuple) + and self._is_empty(data[0])): + return None + + if self._is_embeddings(data): + return AudioEmbeddingItems(data) + + if (is_list_of(data, float) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 1 + or isinstance(data, tuple)): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + new_audios = list[np.ndarray]() + for data_item in data_items: + audio, orig_sr = self._get_audio_with_sr(data_item) + if orig_sr is None: + new_audio = audio + else: + target_sr = self.target_sr + if target_sr is None: + raise RuntimeError( + "Audio resampling is not supported when " + "`target_sr` is not provided") + + new_audio = resample_audio(audio, + 
orig_sr=orig_sr, + target_sr=target_sr) + + new_audios.append(new_audio) + + return AudioProcessorItems(new_audios) + + def _parse_image_data( + self, + data: ModalityData[ImageItem], + ) -> Optional[ModalityDataItems[Any, Any]]: + if self._is_empty(data): + return None + + if self._is_embeddings(data): + return ImageEmbeddingItems(data) + + if (isinstance(data, Image) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 3): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + return ImageProcessorItems(data_items) + + def _parse_video_data( + self, + data: ModalityData[VideoItem], + ) -> Optional[ModalityDataItems[Any, Any]]: + if self._is_empty(data): + return None + + if self._is_embeddings(data): + return VideoEmbeddingItems(data) + + if (is_list_of(data, Image) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 4): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + return VideoProcessorItems(data_items) + + def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: + return { + "audio": self._parse_audio_data, + "image": self._parse_image_data, + "video": self._parse_video_data, + } + + def parse_mm_data(self, + mm_data: MultiModalDataDict) -> MultiModalDataItems: + subparsers = self._get_subparsers() + + mm_items = MultiModalDataItems() + for k, v in mm_data.items(): + if k not in subparsers: + raise ValueError(f"Unsupported modality: {k}") + + # ignore empty embedding data + if (parsed_data := subparsers[k](v)) is not None: + mm_items[k] = parsed_data + + return mm_items diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py new file mode 100644 index 0000000..c8864c3 --- /dev/null +++ b/vllm/multimodal/processing.py @@ -0,0 +1,1705 @@ +# SPDX-License-Identifier: Apache-2.0 +import re +import sys +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, + Sequence) +from dataclasses import dataclass, field +from enum import Enum +from functools import lru_cache +from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, + TypeVar, Union, cast) + +import torch +from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from typing_extensions import assert_never + +from vllm.inputs import InputProcessingContext +from vllm.jsontree import json_map_leaves, json_reduce_leaves +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, + encode_tokens) +from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby + +from .hasher import MultiModalHasher +from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, + MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, + MultiModalKwargsItem, NestedTensors, PlaceholderRange) +from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, + MultiModalDataParser) + +if TYPE_CHECKING: + from .profiling import BaseDummyInputsBuilder + +logger = init_logger(__name__) + +_S = TypeVar("_S", str, list[int]) + +PromptSeq = Union[str, list[int]] +"""A token sequence (list of token IDs) or text.""" + + +@dataclass +class PromptIndex: + """Resolves to an index in the prompt.""" + get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + + +class PromptIndexTargets: + + @staticmethod 
+    def start() -> PromptIndex:
+        """
+        Resolves to the start of the prompt (before the first token).
+
+        This results in a match even if the prompt is empty.
+        """
+        return PromptIndex(lambda tok, prompt: 0)
+
+    @staticmethod
+    def prefix(seq: PromptSeq) -> PromptIndex:
+        """
+        Resolves to a location in the prompt after the given prefix.
+        """
+
+        def get_match_index(
+            tokenizer: AnyTokenizer,
+            prompt: PromptSeq,
+        ) -> Optional[int]:
+            prefix = seq
+
+            if isinstance(prompt, str):
+                if not isinstance(prefix, str):
+                    # Make both `str`
+                    prefix = decode_tokens(tokenizer, prefix)
+            else:
+                if isinstance(prefix, str):
+                    # Make both `list[int]`
+                    prefix = encode_tokens(tokenizer,
+                                           prefix,
+                                           add_special_tokens=False)
+
+            match_idx = len(prefix)
+            return match_idx if prompt[:match_idx] == prefix else None
+
+        return PromptIndex(get_match_index)
+
+    @staticmethod
+    def end() -> PromptIndex:
+        """
+        Resolves to the end of the prompt (after the last token).
+
+        This results in a match even if the prompt is empty.
+        """
+        return PromptIndex(lambda tok, prompt: len(prompt))
+
+
+PromptTarget = Union[PromptSeq, PromptIndex]
+"""
+The token sequence or text to update.
+"""
+
+
+@dataclass
+class PromptUpdateDetails(Generic[_S]):
+    """Details about the token sequence or text that are part of the update."""
+
+    full: _S
+    """The full content."""
+
+    features: _S
+    """
+    The part of the content that corresponds to feature placeholders;
+    this will be replaced by the output of the vision encoder during model
+    inference.
+    """
+
+    @staticmethod
+    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
+        return PromptUpdateDetails(full=seq, features=seq)
+
+
+PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
+"""
+The token sequence or text that are part of the update.
+
+If only part of the content corresponds to feature placeholders, you can
+use :class:`PromptUpdateDetails` to specify which part.
+"""
+
+PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
+                            PromptUpdateInfo]
+"""
+Given the index of the processed item within :attr:`modality`,
+output the corresponding token sequence (or text).
+
+For convenience, you can directly pass in the token sequence (or text)
+instead of a function if it does not depend on the input.
+"""
+
+
+class UpdateMode(str, Enum):
+    INSERT = "insert"
+    REPLACE = "replace"
+
+
+@dataclass
+class PromptUpdate(ABC):
+    """
+    Defines how to update a prompt with placeholder tokens.
+    """
+
+    modality: str
+    """The modality for which the update is made."""
+
+    target: PromptTarget
+    """The token sequence (or text) to update."""
+
+    @property
+    @abstractmethod
+    def content(self) -> PromptUpdateContent:
+        """The placeholder tokens that are part of the update."""
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def mode(self) -> UpdateMode:
+        """Defines how to update the prompt."""
+        raise NotImplementedError
+
+    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
+        return BoundPromptUpdate(
+            _origin=self,
+            tokenizer=tokenizer,
+        )
+
+
+@dataclass
+class PromptInsertion(PromptUpdate):
+    """
+    Defines how to insert placeholder tokens into a prompt.
+
+    Example:
+
+        For each image, insert a number of ``<image>`` feature placeholders
+        equal to the feature size of the vision encoder after the ``<s>`` token:
+
+        .. code-block:: python
+
+            PromptInsertion(
+                modality="image",
+                target="<s>",
+                insertion="<image>" * image_feature_size,
+            )
+
+        Insert these tokens at the start of the prompt:
+
+        .. code-block:: python
+
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.start(),
+                insertion="<image>" * image_feature_size,
+            )
+
+        Insert these tokens after a prefix ``Images:``:
+
+        .. code-block:: python
+
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.prefix("Images:"),
+                insertion="<image>" * image_feature_size,
+            )
+
+        Insert these tokens at the end of the prompt:
+
+        .. code-block:: python
+
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.end(),
+                insertion="<image>" * image_feature_size,
+            )
+    """
+
+    insertion: PromptUpdateContent = field(repr=False)
+    """
+    Given the index of the processed item within :attr:`modality`,
+    output the token sequence (or text) to insert right after :attr:`target`.
+
+    For convenience, you can directly pass in the token sequence (or text)
+    instead of a function if it does not depend on the input.
+    """
+
+    @property
+    def content(self) -> PromptUpdateContent:
+        return self.insertion
+
+    @property
+    def mode(self) -> UpdateMode:
+        return UpdateMode.INSERT
+
+
+@dataclass
+class PromptReplacement(PromptUpdate):
+    """
+    Defines how to replace portions of an input prompt with placeholder tokens.
+
+    Example:
+
+        For each image, replace one ``<image>`` input placeholder in the prompt
+        with a number of ``<image>`` feature placeholders
+        equal to the feature size of the vision encoder:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement="<image>" * image_feature_size,
+            )
+
+        As above, but further pad the feature placeholders with ``<image_bos>``
+        and ``<image_eos>``, which are not supposed to be passed to the vision
+        encoder:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement=PromptUpdateDetails(
+                    full="".join([
+                        "<image_bos>",
+                        "<image>" * image_feature_size,
+                        "<image_eos>",
+                    ]),
+                    features="<image>" * image_feature_size,
+                ),
+            )
+
+        To avoid unnecessary tokenization during prompt replacement,
+        we recommend passing token sequences instead of text:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=PromptUpdateDetails(
+                    full=([image_bos_id] + [image_token_id] * image_feature_size
+                          + [image_eos_id]),
+                    features=[image_token_id] * image_feature_size,
+                ),
+            )
+    """
+
+    replacement: PromptUpdateContent = field(repr=False)
+    """
+    Given the index of the processed item within :attr:`modality`,
+    output the token sequence (or text) to replace :attr:`target`.
+
+    For convenience, you can directly pass in the token sequence (or text)
+    instead of a function if it does not depend on the input.
+    """
+
+    @property
+    def content(self) -> PromptUpdateContent:
+        return self.replacement
+
+    @property
+    def mode(self) -> UpdateMode:
+        return UpdateMode.REPLACE
+
+
+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: Optional[bool] = None,
+) -> list[int]:
+    return encode_tokens(tokenizer,
+                         text,
+                         add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: Optional[bool] = None,
+) -> str:
+    return decode_tokens(tokenizer,
+                         list(token_ids),
+                         skip_special_tokens=skip_special_tokens)
+
+
+class _HasModalityAttr(Protocol):
+    modality: str
+
+
+class _HasModalityProp(Protocol):
+
+    @property
+    def modality(self) -> str:
+        ...
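The `PromptUpdateContent` alias above also accepts a callable that receives the item index, which is useful when items of the same modality need different numbers of placeholders. Below is a minimal, standalone sketch of that pattern; the token ID and per-image feature sizes are made-up illustrative values, not anything defined in this diff:

```python
from vllm.multimodal.processing import PromptReplacement

# Hypothetical values for illustration only.
IMAGE_TOKEN_ID = 32000
IMAGE_FEATURE_SIZES = [576, 144]  # e.g. one high-res and one low-res image


def image_replacement(item_idx: int) -> list[int]:
    # Expand the single image placeholder into one feature placeholder
    # per visual token produced by the encoder for this item.
    return [IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZES[item_idx]


replacement = PromptReplacement(
    modality="image",
    target=[IMAGE_TOKEN_ID],
    replacement=image_replacement,
)
```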
+ + +_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) + + +def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: + """Convenience function to apply :func:`full_groupby` based on modality.""" + return full_groupby(values, key=lambda x: x.modality) + + +@dataclass +class _BoundPromptSequence: + """ + A :data:`_PromptSeq` bound to a tokenizer to automatically + convert between token sequence and text representations. + """ + tokenizer: AnyTokenizer = field(repr=False) + + _text: Optional[str] + _token_ids: Optional[list[int]] + + @staticmethod + def from_seq( + tokenizer: AnyTokenizer, + seq: PromptSeq, + ) -> "_BoundPromptSequence": + return _BoundPromptSequence( + tokenizer=tokenizer, + _text=seq if isinstance(seq, str) else None, + _token_ids=seq if isinstance(seq, list) else None, + ) + + def __post_init__(self) -> None: + if self._text is None and self._token_ids is None: + raise ValueError("At least one of 'text' and 'token_ids' must be " + "specified") + + @property + def text(self) -> str: + if self._text is None: + assert self._token_ids is not None + self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) + + return self._text + + @property + def token_ids(self) -> list[int]: + if self._token_ids is None: + assert self._text is not None + self._token_ids = _cached_encode(self.tokenizer, + self._text, + add_special_tokens=False) + + return self._token_ids + + +@dataclass +class _BoundPromptContent: + full: _BoundPromptSequence + features: _BoundPromptSequence + + +@dataclass +class BoundPromptUpdate: + """ + A :class:`PromptUpdate` bound to a tokenizer to automatically convert + :attr:`target` and the result of :meth:`get_content` between + token sequence and text representations. + """ + _origin: PromptUpdate + tokenizer: AnyTokenizer = field(repr=False) + + def __post_init__(self) -> None: + self._content_cache = dict[int, _BoundPromptContent]() + + @property + def modality(self) -> str: + return self._origin.modality + + @property + def target(self) -> Union[_BoundPromptSequence, PromptIndex]: + """The token sequence (or text) to update.""" + target = self._origin.target + + if isinstance(target, PromptIndex): + return target + + return _BoundPromptSequence.from_seq(self.tokenizer, target) + + @property + def content(self) -> PromptUpdateContent: + """The placeholder tokens that are part of the update.""" + return self._origin.content + + @property + def mode(self) -> UpdateMode: + """Defines how to update the prompt.""" + return self._origin.mode + + def get_content(self, item_idx: int) -> _BoundPromptContent: + """ + Given the index of the processed item within :attr:`modality`, + output the token sequence (or text) to update. 
+ """ + content = self.content + if callable(content): + cache_key = item_idx + if cache_key in self._content_cache: + return self._content_cache[cache_key] + + content = content(item_idx) + else: + cache_key = None + + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + bound_full = _BoundPromptSequence.from_seq(self.tokenizer, + content.full) + bound_features = _BoundPromptSequence.from_seq(self.tokenizer, + content.features) + bound_content = _BoundPromptContent(full=bound_full, + features=bound_features) + + if cache_key is not None: + self._content_cache[cache_key] = bound_content + + return bound_content + + +class _TokenMatch(NamedTuple): + start_idx: int + end_idx: int + + +def iter_token_matches( + token_ids: list[int], + match_ids: list[int], +) -> Generator[_TokenMatch]: + """ + Yield each occurrence of :code:`match_ids` in :code:`token_ids`. + + Note that empty matches are ignored. + """ + prompt_len = len(token_ids) + match_len = len(match_ids) + + if match_len == 0: + return + + start_idx = 0 + while start_idx < prompt_len - match_len + 1: + end_idx = start_idx + match_len + + if token_ids[start_idx:end_idx] == match_ids: + yield _TokenMatch(start_idx=start_idx, end_idx=end_idx) + + # Exclude overlapping matches + start_idx = end_idx + else: + start_idx += 1 + + +def replace_token_matches( + token_ids: list[int], + match_ids: list[int], + new_ids: list[int], +) -> list[int]: + """ + Replace each occurrence of :code:`match_ids` in :code:`token_ids` + with :code:`new_ids`. + + Note that empty matches are ignored. + """ + out_seqs = list[list[int]]() + prev_end_idx = 0 + + for match in iter_token_matches(token_ids, match_ids): + start_idx = match.start_idx + end_idx = match.end_idx + + out_seqs.append(token_ids[prev_end_idx:start_idx]) + out_seqs.append(new_ids) + prev_end_idx = end_idx + + out_seqs.append(token_ids[prev_end_idx:]) + + return flatten_2d_lists(out_seqs) + + +@dataclass(repr=False) +class PromptTargetMatch(ABC): + _origin: BoundPromptUpdate + + @property + def modality(self) -> str: + return self._origin.modality + + @property + @abstractmethod + def start_idx(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def end_idx(self) -> int: + raise NotImplementedError + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r}, " + f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") + + +@dataclass(repr=False) +class _PromptTargetIndexMatch(PromptTargetMatch): + match_idx: int + + @property + def start_idx(self) -> int: + return self.match_idx + + @property + def end_idx(self) -> int: + return self.match_idx + + +@dataclass(repr=False) +class _PromptTargetTokenMatch(PromptTargetMatch): + match: _TokenMatch + + @property + def start_idx(self) -> int: + return self.match.start_idx + + @property + def end_idx(self) -> int: + return self.match.end_idx + + +@dataclass(repr=False) +class _PromptTargetTextMatch(PromptTargetMatch): + match: re.Match[str] + + @property + def start_idx(self) -> int: + return self.match.start() + + @property + def end_idx(self) -> int: + return self.match.end() + + +@dataclass +class PlaceholderFeaturesInfo: + modality: str + item_idx: int + start_idx: int + tokens: list[int] + + @property + def length(self) -> int: + return len(self.tokens) + + def to_range(self) -> PlaceholderRange: + return PlaceholderRange( + offset=self.start_idx, + length=self.length, + ) + + +def find_token_matches( + prompt: list[int], + prompt_updates: 
Sequence[BoundPromptUpdate], +) -> Sequence[PromptTargetMatch]: + """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" + + def get_matches(update: BoundPromptUpdate): + target = update.target + + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(update.tokenizer, prompt) + if match_idx is None: + return [] + + return [_PromptTargetIndexMatch(update, match_idx)] + + return [ + _PromptTargetTokenMatch(update, match) + for match in iter_token_matches(prompt, target.token_ids) + ] + + return [ + match for update in prompt_updates for match in get_matches(update) + ] + + +def find_text_matches( + prompt: str, + prompt_updates: Sequence[BoundPromptUpdate], +) -> Sequence[PromptTargetMatch]: + """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" + + def get_matches(update: BoundPromptUpdate): + target = update.target + + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(update.tokenizer, prompt) + if match_idx is None: + return [] + + return [_PromptTargetIndexMatch(update, match_idx)] + + return [ + _PromptTargetTextMatch(update, match) + for match in re.finditer(re.escape(target.text), prompt) + ] + + return [ + match for update in prompt_updates for match in get_matches(update) + ] + + +def _resolve_matches( + prompt: PromptSeq, + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], +) -> list[PromptTargetMatch]: + """ + Resolve :code:`mm_matches` to ensure that there are no overlapping matches, + and sort them such that earlier matches take priority over later ones. + """ + matches = [m for matches in mm_matches.values() for m in matches] + + seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) + + for match in matches: + for idx in range(match.start_idx, match.end_idx): + if seen_matches[idx] is not None: + raise ValueError("Found overlapping matches " + f"({seen_matches[idx]} and {match}) " + f"at index={idx} of prompt={prompt}") + + seen_matches[idx] = match + + return sorted(matches, key=lambda x: x.start_idx) + + +def _apply_matches( + prompt: _S, + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], +) -> list[_S]: + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" + out_seqs = list[Union[str, list[int]]]() + prev_end_idx = 0 + next_idx_by_modality = defaultdict[str, int](lambda: 0) + + for match in _resolve_matches(prompt, mm_matches): + modality = match.modality + + item_start_idx = next_idx_by_modality[modality] + max_item_count = mm_item_counts.get(modality, 0) + if item_start_idx >= max_item_count: + continue + + start_idx = match.start_idx + end_idx = match.end_idx + origin = match._origin + mode = origin.mode + + if mode == UpdateMode.INSERT: + out_seqs.append(prompt[prev_end_idx:end_idx]) + num_inserts = max_item_count + elif mode == UpdateMode.REPLACE: + out_seqs.append(prompt[prev_end_idx:start_idx]) + num_inserts = max_item_count if start_idx == end_idx else 1 + else: + assert_never(mode) + + item_end_idx = min(item_start_idx + num_inserts, max_item_count) + + for item_idx in range(item_start_idx, item_end_idx): + content = origin.get_content(item_idx) + insert_seq = (content.full.text if isinstance(prompt, str) else + content.full.token_ids) + + out_seqs.append(insert_seq) + + prev_end_idx = end_idx + next_idx_by_modality[modality] += item_end_idx - item_start_idx + + out_seqs.append(prompt[prev_end_idx:]) + + return cast(list[_S], out_seqs) + + +def apply_token_matches( + prompt: list[int], + mm_matches: 
Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], +) -> list[int]: + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" + if not mm_matches: + return prompt + + token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) + + return flatten_2d_lists(token_id_seqs) + + +def apply_text_matches( + prompt: str, + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], +) -> str: + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" + if not mm_matches: + return prompt + + texts = _apply_matches(prompt, mm_matches, mm_item_counts) + + return "".join(texts) + + +def _iter_placeholders( + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + prompt: list[int], + mm_item_counts: Mapping[str, int], +) -> Iterable[PlaceholderFeaturesInfo]: + """ + Yield each set of placeholder tokens found in :code:`prompt`. + + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + + Note that empty matches are ignored. + """ + prompt_len = len(prompt) + item_idx_by_modality = defaultdict[str, int](lambda: 0) + + start_idx = 0 + while start_idx < prompt_len: + found = False + + for modality, modality_updates in mm_prompt_updates.items(): + item_idx = item_idx_by_modality[modality] + if item_idx >= mm_item_counts.get(modality, 0): + continue + + for update_info in modality_updates: + content = update_info.get_content(item_idx) + content_tokens_full = content.full.token_ids + content_len_full = len(content_tokens_full) + end_idx_full = start_idx + content_len_full + + if content_len_full == 0 or end_idx_full > prompt_len: + continue + + if prompt[start_idx:end_idx_full] == content_tokens_full: + content_tokens_feat = content.features.token_ids + + try: + match = next( + iter_token_matches(content_tokens_full, + content_tokens_feat)) + yield PlaceholderFeaturesInfo( + modality=modality, + item_idx=item_idx, + start_idx=start_idx + match.start_idx, + tokens=content_tokens_feat, + ) + except StopIteration: + raise AssertionError( + f"{content_tokens_feat=} should be a " + f"subsequence of {content_tokens_full=}") from None + + # Exclude overlapping matches + start_idx = end_idx_full + item_idx_by_modality[modality] += 1 + found = True + break + + if found: + break # Go back to the outer while loop + + if not found: + start_idx += 1 + + +def find_mm_placeholders( + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + prompt: list[int], + mm_item_counts: Mapping[str, int], +) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) + return dict(full_groupby_modality(it)) + + +_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]") + + +class ProcessingCache: + + @staticmethod + def get_lru_cache( + capacity_gb: float, + value_type: type[_V], + *, + debug: bool = False, + ) -> LRUCache[str, _V]: + + def get_leaf_size(leaf: object) -> int: + # MultiModalKwargs is not a subclass of dict + if isinstance(leaf, MultiModalKwargs): + return get_item_size(leaf.data) + + # MultiModalKwargsItem is not a subclass of dict + if isinstance(leaf, MultiModalKwargsItem): + leaf_data = {k: v.data for k, v in leaf.items()} + return get_item_size(leaf_data) + + # sys.getsizeof doesn't work for tensors + if isinstance(leaf, torch.Tensor): + return leaf.nbytes + + return sys.getsizeof(leaf) + + def get_item_size( + value: 
Union[MultiModalKwargs, MultiModalKwargsItem, + Mapping[str, NestedTensors]] + ) -> int: + size = json_reduce_leaves( + lambda a, b: a + b, + json_map_leaves(get_leaf_size, value), + ) + + if debug: + logger.debug("Calculated size of %s to be %.2f GiB", + type(value), size / GiB_bytes) + + return size + + return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size) + + def __init__( + self, + capacity_gb: float, + *, + debug_cache_hit_ratio_steps: Optional[int] = None, + ) -> None: + super().__init__() + + self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps + self.debug_cache_hits = 0 + self.debug_cache_total = 0 + + self._cache = self.get_lru_cache( + capacity_gb, + MultiModalKwargsItem, + debug=bool(debug_cache_hit_ratio_steps), + ) + + def _maybe_log_cache_stats(self) -> None: + steps = self.debug_cache_hit_ratio_steps + if not steps: + return + + total = self.debug_cache_total + if total > 0 and total % steps == 0: + logger.debug("ProcessingCache: hit_ratio = %.2f", + self.debug_cache_hits / total) + logger.debug("ProcessingCache: size = %.2f / %.2f GiB", + self._cache.currsize / GiB_bytes, + self._cache.maxsize / GiB_bytes) + + def get( + self, + model_id: str, + modality: str, + input_item: object, + input_kwargs: Mapping[str, object], + ) -> Optional[MultiModalKwargsItem]: + """ + Get a processed multi-modal item from the cache + according to its dependencies, including: + + - The model ID + - The modality of the item + - The original data item passed to the HF processor + - The configuration options of the HF processor + """ + self._maybe_log_cache_stats() + + cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: input_item}, + **input_kwargs) + + if self.debug_cache_hit_ratio_steps: + if cache_key in self._cache: + self.debug_cache_hits += 1 + + self.debug_cache_total += 1 + + return self._cache.get(cache_key) + + def put( + self, + model_id: str, + modality: str, + input_item: object, + input_kwargs: Mapping[str, object], + output_kwargs: MultiModalKwargsItem, + ) -> None: + """ + Put a processed multi-modal item into the cache + according to its dependencies (see :meth:`get`). + """ + cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: input_item}, + **input_kwargs) + self._cache[cache_key] = output_kwargs + + +class BaseProcessingInfo: + """Base class to provide the information necessary for data processing.""" + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__() + + self.ctx = ctx + + @property + def model_id(self) -> str: + return self.ctx.model_config.model + + def get_tokenizer(self) -> AnyTokenizer: + return self.ctx.tokenizer + + def get_hf_config(self) -> PretrainedConfig: + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: + """ + Subclasses can override this method to handle + specific kwargs from model config or user inputs. + """ + return self.ctx.get_hf_processor(**kwargs) + + @abstractmethod + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + """ + Return the maximum supported number of items for each modality. + + A value of `None` means unlimited number of items. + + Omitting a modality from the returned dictionary means that + it is not supported at all. + """ + raise NotImplementedError + + @abstractmethod + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + """ + Get the maximum possible number of tokens per data item + for each modality. 
+ + The dictionary returned by this method should have the same + keys as that returned by :meth:`get_supported_mm_limits`. + """ + raise NotImplementedError + + +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class BaseMultiModalProcessor(ABC, Generic[_I]): + """ + Abstract base class to process multi-modal inputs to be used in vLLM. + + Not to be confused with :class:`transformers.ProcessorMixin`. + """ + + def __init__(self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: + if get_repls := getattr(self, "_get_prompt_replacements", None): + logger.warning_once("`_get_prompt_replacements` has been renamed " + "to `_get_prompt_updates`. The old name will " + "be removed in an upcoming release.") + self._get_prompt_updates = get_repls # type: ignore[method-assign] + + super().__init__() + + self.info = info + self.dummy_inputs = dummy_inputs + self.cache = cache + self.enable_sanity_checks = enable_sanity_checks + + self.data_parser = self._get_data_parser() + + def __call__( + self, + prompt: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputs: + return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + + def _get_data_parser(self) -> MultiModalDataParser: + """ + Construct a parser to preprocess multi-modal data items + before passing them to :meth:`_get_hf_mm_data`. + + You can support additional modalities by creating a subclass + of :class:`MultiModalDataParser` that has additional subparsers. + """ + return MultiModalDataParser() + + def _to_mm_items( + self, + mm_data: MultiModalDataDict, + ) -> MultiModalDataItems: + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` + before passing them to :meth:`_get_hf_mm_data`. + """ + mm_items = self.data_parser.parse_mm_data(mm_data) + mm_config = self.info.ctx.get_mm_config() + + for modality, items in mm_items.items(): + limit = mm_config.get_limit_per_prompt(modality) + if len(items) > limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but passed {len(items)} " + f"{modality} items in the same prompt.") + + return mm_items + + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + """Given the HF-processed data, output the metadata of each field.""" + raise NotImplementedError + + @abstractmethod + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + """ + Given the original multi-modal items for this modality + and HF-processed data, output the updates to perform. + + The information returned by this method is used to update token inputs + which bypass the HF processor. It is also used to update the output of + HF processor if the HF process does not apply prompt updates to text + inputs. + + Moreover, this information is critical to determine the token positions + in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` + for each multi-modal item. 
+ """ + raise NotImplementedError + + def _find_mm_placeholders( + self, + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + new_token_ids: list[int], + mm_item_counts: Mapping[str, int], + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + return find_mm_placeholders(mm_prompt_updates, new_token_ids, + mm_item_counts) + + def _get_hf_mm_data( + self, + mm_items: MultiModalDataItems, + ) -> tuple[Mapping[str, object], Mapping[str, object]]: + processor_data = dict[str, object]() + passthrough_data = dict[str, object]() + + for items in mm_items.values(): + processor_data.update(items.get_processor_data()) + passthrough_data.update(items.get_passthrough_data()) + + return processor_data, passthrough_data + + def _call_hf_processor( + self, + prompt: str, + # Not to be confused with `mm_data` in `self.apply`. + # This refers to the data to be passed to HF processor. + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + """ + Call the HF processor on the prompt text and + associated multi-modal data. + """ + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + """ + Return whether the HF processor applies prompt updates. + + For most HF processors, this should be :code:`True` when multi-modal + data items are passed, but :code:`False` when multi-modal embeddings + are passed. + """ + return not any( + isinstance(items, (EmbeddingItems, DictEmbeddingItems)) + for items in mm_items.values()) + + def _apply_hf_processor_text_mm( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs, bool]: + """ + Apply the HF processor on the prompt text and multi-modal data + together. + + In addition, return whether prompt updates have been applied. + """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + ) + processed_data.update(passthrough_data) + + prompt_ids, = processed_data.pop("input_ids").tolist() + + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + ) + + is_update_applied = self._hf_processor_applies_updates( + prompt_text=prompt_text, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_kwargs, is_update_applied + + def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: + """ + Apply the HF processor on the prompt text only. + + Since HF processor requires that text and multi-modal items + correspond to each other, we create dummy multi-modal items + to go along with the text. + """ + prompt_ids, _, _ = self._apply_hf_processor_text_mm( + prompt_text=prompt_text, + mm_items=MultiModalDataItems({}), + hf_processor_mm_kwargs={}, + ) + + return prompt_ids + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + """ + Apply the HF processor on the prompt tokens only. + + Most HF processors accept prompt text but not prompt tokens. 
+ If the HF processor adds or removes tokens that are not related to + multi-modal data, you should override this method so it is consistent + with the output of :meth:`_apply_hf_processor_text_only` on the + corresponding text. + """ + return prompt_tokens + + def _apply_hf_processor_mm_only( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalKwargs: + """ + Apply the HF processor on the multi-modal data only. + + Since HF processor requires that text and multi-modal items + correspond to each other, we generate dummy text using + :class:`DummyInputsBuilder` to go along with the multi-modal data. + """ + mm_counts = mm_items.get_all_counts() + + dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs( + self.info.ctx.model_config.max_model_len, + mm_counts, + ) + + _, mm_kwargs, _ = self._apply_hf_processor_text_mm( + prompt_text=dummy_inputs.prompt_text, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return mm_kwargs + + def _apply_hf_processor_main( + self, + prompt: Union[str, list[int]], + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + *, + enable_hf_prompt_update: bool, + ) -> tuple[list[int], MultiModalKwargs, bool]: + """ + Apply the HF processor on the prompt text and multi-modal data. + + In addition, return whether prompt updates have been applied + (for most HF processors, this should be :code:`True`). + + Note: + If :code:`enable_hf_prompt_update=False`, we use HF processor + to perform prompt updates if available; HF processor requires + that the prompt corresponds to multi-modal items. + """ + if isinstance(prompt, str): + if enable_hf_prompt_update: + return self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + prompt_ids = self._apply_hf_processor_text_only(prompt) + else: + prompt_ids = self._apply_hf_processor_tokens_only(prompt) + + mm_kwargs = self._apply_hf_processor_mm_only( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_kwargs, False + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs, bool]: + """ + Apply the HF processor on the full prompt text, + caching the results and reusing cached results. 
+ """ + cache = self.cache + model_id = self.info.model_id + + _, passthrough_data = self._get_hf_mm_data(mm_data_items) + if cache is None or passthrough_data: + return self._apply_hf_processor_main( + prompt=prompt, + mm_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + enable_hf_prompt_update=True, + ) + + mm_maybe_cached_kw_items = { + modality: [ + cache.get(model_id, modality, item, hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_data_items.items() + } + + mm_missing_idxs = { + modality: + [idx for idx, item in enumerate(kw_items) if item is None] + for modality, kw_items in mm_maybe_cached_kw_items.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + mm_missing_data_items = self._to_mm_items(mm_missing_data) + + # NOTE: `prompt` does not correspond to `mm_missing_data_items`, + # so we can't apply prompt updates until the new multimodal + # items are combined with the cached multimodal items + ( + prompt_ids, + mm_missing_kwargs, + is_update_applied, + ) = self._apply_hf_processor_main( + prompt=prompt, + mm_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + enable_hf_prompt_update=False, + ) + + mm_missing_next_idx = { + modality: 0 + for modality in mm_missing_data_items + } + + merged_kw_items = list[MultiModalKwargsItem]() + for modality, kw_items in mm_maybe_cached_kw_items.items(): + for idx, kw_item in enumerate(kw_items): + if kw_item is None: + kw_item = mm_missing_kwargs.get_item( + modality, + mm_missing_next_idx[modality], + ) + + cache.put( + model_id, + modality, + mm_data_items[modality][idx], + hf_processor_mm_kwargs, + kw_item, + ) + + mm_missing_next_idx[modality] += 1 + + merged_kw_items.append(kw_item) + + if self.enable_sanity_checks: + mm_missing_counts = mm_missing_data_items.get_all_counts() + assert all( + item_count == mm_missing_counts[modality] + for modality, item_count in mm_missing_next_idx.items()), dict( + mm_missing_next_idx=mm_missing_next_idx, + mm_missing_counts=mm_missing_counts) + + mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) + + return prompt_ids, mm_kwargs, is_update_applied + + def _bind_and_group_updates( + self, + prompt_updates: Sequence[PromptUpdate], + ) -> dict[str, Sequence[BoundPromptUpdate]]: + tokenizer = self.info.get_tokenizer() + + it = (update.bind(tokenizer) for update in prompt_updates) + return dict(full_groupby_modality(it)) + + def _apply_token_matches( + self, + prompt: list[int], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], + ) -> list[int]: + return apply_token_matches(prompt, mm_matches, mm_item_counts) + + def _apply_text_matches( + self, + prompt: str, + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], + ) -> str: + return apply_text_matches(prompt, mm_matches, mm_item_counts) + + def _apply_prompt_updates( + self, + token_ids: list[int], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + tokenizer = self.info.get_tokenizer() + + mm_token_matches = { + modality: find_token_matches(token_ids, updates) + for modality, updates in mm_prompt_updates.items() + } + mm_match_counts = { + modality: len(matches) + for modality, matches in mm_token_matches.items() + } + + # If the search text does not represent a special token, + 
# it may have different token IDs in the prompt, because + # the tokens may go across the boundaries of the search text. + # ---- + # e.g. when searching for "foo" in "food", if "food" itself makes + # up a token, then the token ID of "foo" will not appear at all + # ---- + # Since it is inefficient to search for all possible tokenizations + # of the search text in the prompt, we instead perform string-based + # updates on the decoded token IDs, then encode them back. + if all( + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() + ): # yapf: disable + token_ids = self._apply_token_matches( + token_ids, + mm_token_matches, + mm_item_counts, + ) + + text = decode_tokens(tokenizer, token_ids) + matched_updates = { + modality: [match._origin for match in token_matches] + for modality, token_matches in mm_token_matches.items() + } + else: + text = decode_tokens(tokenizer, token_ids) + + mm_text_matches = { + modality: find_text_matches(text, updates) + for modality, updates in mm_prompt_updates.items() + } + text = self._apply_text_matches( + text, + mm_text_matches, + mm_item_counts, + ) + + token_ids = encode_tokens(tokenizer, + text, + add_special_tokens=False) + matched_updates = { + modality: [match._origin for match in token_matches] + for modality, token_matches in mm_text_matches.items() + } + + placeholders = self._find_mm_placeholders( + matched_updates, + token_ids, + mm_item_counts, + ) + + return token_ids, text, placeholders + + def _validate_mm_kwargs( + self, + mm_kwargs: MultiModalKwargs, + mm_item_counts: Mapping[str, int], + ) -> None: + for modality, item_count in mm_item_counts.items(): + if modality in mm_kwargs.modalities: + items = mm_kwargs.get_items(modality) + else: + items = [] + + if len(items) != item_count: + raise RuntimeError( + f"Expected there to be {item_count} {modality} items in " + f"keyword arguments corresponding to {item_count} " + f"{modality} data items, but only found {len(items)}! " + "There is likely a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_mm_fields_config`).") + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + ) -> None: + for modality, item_count in mm_item_counts.items(): + placeholders = mm_placeholders.get(modality, []) + + if len(placeholders) != item_count: + raise RuntimeError( + f"Expected there to be {item_count} prompt updates " + f"corresponding to {item_count} {modality} items, but " + f"instead found {len(placeholders)} prompt updates! " + "Either the prompt text has missing/incorrect tokens for " + "multi-modal inputs, or there is a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_prompt_updates`).") + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + The main steps are: + + 1. Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + 2. Find and update sequences in the token IDs with placeholder tokens. 
+ The number of placeholder tokens equals the feature size of the + multi-modal data outputted by the multi-modal encoder. + 3. Extract information about the placeholder tokens from the + processed token IDs. + """ + mm_items = self._to_mm_items(mm_data) + + # Create MM hashes to be returned (only used in V1) + # TODO: Use these hash keys for caching operations in apply_hf_processor + # instead of rehashing. + + if return_mm_hashes: + model_id = self.info.model_id + mm_hashes = { + modality: [ + MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_items.items() + } + else: + mm_hashes = None + + ( + prompt_ids, + mm_kwargs, + is_update_applied, + ) = self._cached_apply_hf_processor( + prompt, + mm_items, + hf_processor_mm_kwargs, + ) + + unbound_prompt_updates = self._get_prompt_updates( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) + + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + if is_update_applied: + mm_placeholders = self._find_mm_placeholders( + mm_prompt_updates, + prompt_ids, + mm_item_counts, + ) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + else: + ( + prompt_ids, + prompt, + mm_placeholders, + ) = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + mm_item_counts, + ) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + mm_placeholder_ranges = { + modality: [item.to_range() for item in placeholders] + for modality, placeholders in mm_placeholders.items() + } + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholder_ranges, + ) + + +class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): + + @abstractmethod + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + """ + Create input prompt for the encoder. HF processor will be applied on + this prompt during profiling and generation. + """ + raise NotImplementedError + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + """Create input prompt for the decoder.""" + return prompt + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalEncDecInputs: + """ + Process multi-modal inputs to be used in vLLM. + The main processing steps are modified to fit encoder-decoder model: + 1. Create encoder prompt from input prompt text. + 2. Apply the HF processor on encoder prompt. + 3. Copy the input prompt text as decoder prompt inputs. 
+ """ + encoder_prompt = self.create_encoder_prompt(prompt, mm_data) + encoder_inputs = super().apply( + encoder_prompt, + mm_data, + hf_processor_mm_kwargs, + return_mm_hashes, + ) + + tokenizer = self.info.get_tokenizer() + decoder_prompt = self.create_decoder_prompt(prompt, mm_data) + if isinstance(decoder_prompt, str): + decoder_prompt_ids = encode_tokens(tokenizer, + decoder_prompt, + add_special_tokens=False) + else: + decoder_prompt_ids = decoder_prompt + decoder_prompt = decode_tokens(tokenizer, decoder_prompt) + + mm_inputs = MultiModalEncDecInputs( + encoder_prompt=encoder_inputs["prompt"], + encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], + **encoder_inputs) + mm_inputs.update({ + "prompt": decoder_prompt, + "prompt_token_ids": decoder_prompt_ids + }) + return mm_inputs diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py new file mode 100644 index 0000000..1df9a1f --- /dev/null +++ b/vllm/multimodal/profiling.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Generic, NamedTuple, Optional, TypeVar, cast + +import numpy as np +import numpy.typing as npt +from PIL import Image + +import vllm.envs as envs +from vllm.logger import init_logger + +from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, + MultiModalInputs, MultiModalKwargs, + MultiModalPlaceholderDict) +from .processing import BaseMultiModalProcessor, BaseProcessingInfo + +logger = init_logger(__name__) + + +@dataclass +class ProcessorInputs: + """ + Represents the keyword arguments to + :meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`. + """ + prompt_text: str + mm_data: MultiModalDataDict + hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) + + +class DummyEncoderData(NamedTuple): + """Dummy data used for profiling.""" + + prompt_token_ids: list[int] + + +class DummyDecoderData(NamedTuple): + """Dummy data used for profiling.""" + + prompt_token_ids: list[int] + multi_modal_data: MultiModalKwargs + multi_modal_placeholders: MultiModalPlaceholderDict + + +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class BaseDummyInputsBuilder(ABC, Generic[_I]): + """ + Abstract base class that constructs the dummy data to profile + multi-modal models. + """ + + def __init__(self, info: _I) -> None: + super().__init__() + + self.info = info + + @abstractmethod + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + """ + Build the input which, after processing, results in + :code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens. + """ + raise NotImplementedError + + def _get_dummy_audios( + self, + *, + length: int, + num_audios: int, + ) -> list[npt.NDArray]: + audio = np.zeros((length, )) + return [audio] * num_audios + + def _get_dummy_images( + self, + *, + width: int, + height: int, + num_images: int, + ) -> list[Image.Image]: + image = Image.new("RGB", (width, height), color=255) + return [image] * num_images + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + ) -> list[npt.NDArray]: + video = np.full((num_frames, width, height, 3), 255) + return [video] * num_videos + + +class MultiModalProfiler(Generic[_I]): + """ + Contains code for running memory profiling for multi-modal models. 
+ """ + + def __init__( + self, + processor: BaseMultiModalProcessor[_I], + ) -> None: + super().__init__() + + self.processor = processor + + @property + def processing_info(self) -> BaseProcessingInfo: + return self.processor.info + + @property + def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]: + return self.processor.dummy_inputs + + def get_mm_limits(self) -> Mapping[str, int]: + mm_config = self.processing_info.ctx.get_mm_config() + supported_mm_limits = self.processing_info.get_supported_mm_limits() + + mm_limits = { + modality: mm_config.get_limit_per_prompt(modality) + for modality in supported_mm_limits + } + + for modality, supported_limit in supported_mm_limits.items(): + limit = mm_limits[modality] + if supported_limit is not None and supported_limit < limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but this model only supports " + f"at most {supported_limit} {modality} items.") + + return mm_limits + + def _get_dummy_mm_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalInputs: + factory = self.dummy_inputs + processor_inputs = factory.get_dummy_processor_inputs( + seq_len, mm_counts) + + return self.processor.apply( + prompt=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) + + def get_and_validate_mm_inputs( + self, + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> tuple[MultiModalInputs, Mapping[str, int]]: + if mm_counts is None: + mm_counts = self.get_mm_limits() + + info = self.processing_info + mm_max_tokens_per_item = info.get_mm_max_tokens_per_item( + seq_len, mm_counts) + + if mm_counts.keys() - mm_max_tokens_per_item.keys(): + raise AssertionError( + "The keys returned by `get_supported_mm_limits` " + f"({set(mm_counts.keys())}) should be a subset of those " + "returned by `get_mm_max_tokens_per_item` " + f"({set(mm_max_tokens_per_item.keys())})") + + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + placeholders_by_modality = mm_inputs["mm_placeholders"] + + total_placeholders_by_modality = { + modality: sum(item["length"] for item in placeholders) + for modality, placeholders in placeholders_by_modality.items() + } + expected_placeholders_by_modality = { + modality: mm_max_tokens_per_item[modality] * mm_counts[modality] + for modality in placeholders_by_modality + } + if total_placeholders_by_modality != expected_placeholders_by_modality: + raise AssertionError( + f"The processed dummy data has a total of " + f"{total_placeholders_by_modality} placeholder tokens, which " + f"is not the expected {expected_placeholders_by_modality} " + "tokens.") + return mm_inputs, total_placeholders_by_modality + + def get_encoder_dummy_data( + self, + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> DummyEncoderData: + mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts) + mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) + + # For encoder-decoder models, use encoder prompt token ids instead of + # decoder prompt to construct dummy seq_data for encoder profiling. 
+ encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"] + + total_len = len(encoder_prompt_token_ids) + num_tokens_to_pad = max(total_len, seq_len) - total_len + encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) + + return DummyEncoderData(encoder_prompt_token_ids) + + def get_decoder_dummy_data( + self, + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> DummyDecoderData: + ( + mm_inputs, + total_placeholders_by_modality, + ) = self.get_and_validate_mm_inputs(seq_len, mm_counts) + + prompt_token_ids = mm_inputs["prompt_token_ids"] + total_len = len(prompt_token_ids) + + # V0 does not support chunked prefill. + if total_len > seq_len and not envs.VLLM_USE_V1: + # `max_num_batched_tokens` is defined by `SchedulerConfig` + logger.warning( + "The sequence length used for profiling (" + "max_num_batched_tokens / max_num_seqs = %d) is too short " + "to hold the multi-modal embeddings in the worst case " + "(%d tokens in total, out of which %s are reserved for " + "multi-modal embeddings). This may cause certain " + "multi-modal inputs to fail during inference, even when " + "the input text is short. To avoid this, you should " + "increase `max_model_len`, reduce `max_num_seqs`, " + "and/or reduce `mm_counts`.", seq_len, total_len, + total_placeholders_by_modality) + + if total_len < seq_len: + prompt_token_ids.extend([0] * (seq_len - total_len)) + + return DummyDecoderData( + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_placeholders=mm_inputs["mm_placeholders"], + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py new file mode 100644 index 0000000..4f41fa0 --- /dev/null +++ b/vllm/multimodal/registry.py @@ -0,0 +1,503 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +from collections import UserDict +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar + +import torch.nn as nn + +from vllm.envs import VLLM_MM_INPUT_CACHE_GIB +from vllm.inputs import InputProcessingContext +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + cached_tokenizer_from_config) +from vllm.utils import ClassRegistry + +from .audio import AudioPlugin +from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc +from .image import ImagePlugin +from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors +from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, + ProcessingCache) +from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, + DummyEncoderData, MultiModalProfiler) +from .video import VideoPlugin + +if TYPE_CHECKING: + from vllm.config import ModelConfig + +logger = init_logger(__name__) + +N = TypeVar("N", bound=type[nn.Module]) +_I = TypeVar("_I", bound=BaseProcessingInfo) +_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True) + + +class ProcessingInfoFactory(Protocol[_I_co]): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + ctx: InputProcessingContext, + ) -> _I_co: + ... + + +class DummyInputsBuilderFactory(Protocol[_I]): + """ + Constructs a :class:`BaseDummyInputsBuilder` instance from the context. + """ + + def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: + ... 
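As a usage reference for the profiling helpers above, the sketch below shows what a dummy-inputs builder for a hypothetical image-only model could look like, wiring `ProcessorInputs` together with the `_get_dummy_images` helper from `profiling.py`. The `<image>` placeholder string and the 336x336 resolution are assumptions made up for this example, not values taken from this diff:

```python
from collections.abc import Mapping

from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class MyDummyInputsBuilder(BaseDummyInputsBuilder):
    """Sketch of a dummy-inputs builder for a hypothetical image model."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        # "<image>" and 336x336 are illustrative; a real model would derive
        # the placeholder token and resolution from its HF processor/config.
        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data={
                "image":
                self._get_dummy_images(width=336,
                                       height=336,
                                       num_images=num_images),
            },
        )
```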
+ + +class MultiModalProcessorFactory(Protocol[_I]): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + info: _I, + dummy_inputs: BaseDummyInputsBuilder[_I], + *, + cache: Optional[ProcessingCache] = None, + ) -> BaseMultiModalProcessor[_I]: + ... + + +@dataclass(frozen=True) +class _ProcessorFactories(Generic[_I]): + info: ProcessingInfoFactory[_I] + processor: MultiModalProcessorFactory[_I] + dummy_inputs: DummyInputsBuilderFactory[_I] + + def build_processor( + self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + ): + info = self.info(ctx) + dummy_inputs_builder = self.dummy_inputs(info) + return self.processor(info, dummy_inputs_builder, cache=cache) + + +class _MultiModalLimits(UserDict["ModelConfig", dict[str, int]]): + """ + Wraps `_limits_by_model` for a more informative error message + when attempting to access a model that does not exist. + """ + + def __getitem__(self, key: "ModelConfig") -> dict[str, int]: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"Cannot find `mm_limits` for model={key.model}. Did you " + "forget to call `init_mm_limits_per_prompt`?") + raise KeyError(msg) from exc + + +class MultiModalRegistry: + """ + A registry that dispatches data processing according to the model. + """ + + DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin()) + + def __init__( + self, + *, + plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: + self._plugins = {p.get_data_key(): p for p in plugins} + + self._processor_factories = ClassRegistry[nn.Module, + _ProcessorFactories]() + + # This is used for non-multimodal models + self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} + + self._limits_by_model = _MultiModalLimits() + + self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) + + def register_plugin(self, plugin: MultiModalPlugin) -> None: + """ + Register a multi-modal plugin so it can be recognized by vLLM. + """ + data_type_key = plugin.get_data_key() + + if data_type_key in self._plugins: + logger.warning( + "A plugin is already registered for data type %s, " + "and will be overwritten by the new plugin %s.", data_type_key, + plugin) + + self._plugins[data_type_key] = plugin + + def _get_plugin(self, data_type_key: str): + plugin = self._plugins.get(data_type_key) + if plugin is not None: + return plugin + + msg = f"Unknown multi-modal data type: {data_type_key}" + raise NotImplementedError(msg) + + def register_input_mapper( + self, + data_type_key: str, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper for a specific modality to a model class. + + See :meth:`MultiModalPlugin.register_input_mapper` for more details. + """ + return self._get_plugin(data_type_key).register_input_mapper(mapper) + + def register_image_input_mapper( + self, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper for image data to a model class. + + See :meth:`MultiModalPlugin.register_input_mapper` for more details. + """ + return self.register_input_mapper("image", mapper) + + def map_input( + self, + model_config: "ModelConfig", + data: MultiModalDataDict, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ) -> MultiModalKwargs: + """ + Apply an input mapper to the data passed to the model. 
+ + The data belonging to each modality is passed to the corresponding + plugin which in turn converts the data into into keyword arguments + via the input mapper registered for that model. + + See :meth:`MultiModalPlugin.map_input` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + merged_dict = dict[str, NestedTensors]() + + for data_key, data_value in data.items(): + plugin = self._get_plugin(data_key) + + num_items = len(data_value) if isinstance(data_value, list) else 1 + max_items = self._limits_by_model[model_config][data_key] + if num_items > max_items: + raise ValueError( + f"You set {data_key}={max_items} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but found {num_items} items " + "in the same prompt.") + + input_dict = plugin.map_input(model_config, data_value, + mm_processor_kwargs) + for input_key, input_tensor in input_dict.items(): + if input_key in merged_dict: + raise ValueError(f"The input mappers (keys={set(data)}) " + f"resulted in a conflicting keyword " + f"argument to `forward()`: {input_key}") + + merged_dict[input_key] = input_tensor + + return MultiModalKwargs(merged_dict) + + def create_input_mapper(self, model_config: "ModelConfig"): + """ + Create an input mapper (see :meth:`map_input`) for a specific model. + """ + # NOTE - we currently make the assumption that if a model has multiple + # supported modalities, they take the same kwargs. For the default, + # this could be an issue in the future if it falls back to two HF + # resources and we can't inspect the signature easily since it's + # getting initialized through the autoclass. + # + # If this is a problem in the future, we should revisit it, but since + # it potentially introduces a lot of complexity for a currently + # uncommon case, we do not for simplicity of both use & implementation + return functools.partial(self.map_input, model_config) + + def register_max_multimodal_tokens( + self, + data_type_key: str, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of tokens, corresponding to a single + instance of multimodal data belonging to a specific modality, that are + passed to the language model for a model class. + """ + return self._get_plugin(data_type_key) \ + .register_max_multimodal_tokens(max_mm_tokens) + + def register_max_image_tokens( + self, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of image tokens, corresponding to a single + image, that are passed to the language model for a model class. + """ + return self.register_max_multimodal_tokens("image", max_mm_tokens) + + def get_max_tokens_per_item_by_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: + """ + Get the maximum number of tokens per data item from each modality based + on underlying model configuration. 
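+
+        If the model has a merged multi-modal processor registered, the
+        values come from ``processor.info``; otherwise the legacy
+        per-modality plugins are consulted.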
+ """ + if self.has_processor(model_config): + processor = self.create_processor(model_config, disable_cache=True) + seq_len = model_config.max_model_len + mm_limits = self.get_mm_limits_per_prompt(model_config) + return processor.info.get_mm_max_tokens_per_item( + seq_len, mm_limits) + + return { + key: plugin.get_max_multimodal_tokens(model_config) + for key, plugin in self._plugins.items() + } + + def get_max_tokens_per_item_by_nonzero_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: + """ + Get the maximum number of tokens per data item from each modality based + on underlying model configuration, excluding modalities that user + explicitly disabled via `limit_mm_per_prompt`. + + Note: + This is currently directly used only in V1 for profiling the memory + usage of a model. + """ + mm_limits = self.get_mm_limits_per_prompt(model_config) + + return { + key: max_tokens_per_mm_item + for key, max_tokens_per_mm_item in + self.get_max_tokens_per_item_by_modality(model_config).items() + if mm_limits[key] > 0 + } + + def get_max_tokens_by_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: + """ + Get the maximum number of tokens from each modality + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + mm_limits = self.get_mm_limits_per_prompt(model_config) + + return { + key: mm_limits[key] * max_tokens_per_mm_item + for key, max_tokens_per_mm_item in + self.get_max_tokens_per_item_by_modality(model_config).items() + } + + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + return sum(self.get_max_tokens_by_modality(model_config).values()) + + def init_mm_limits_per_prompt( + self, + model_config: "ModelConfig", + ) -> None: + """ + Initialize the maximum number of multi-modal input instances for each + modality that are allowed per prompt for a model class. + """ + if model_config in self._limits_by_model: + logger.warning( + "`mm_limits` has already been set for model=%s, and will " + "be overwritten by the new values.", model_config.model) + + multimodal_config = model_config.multimodal_config + if multimodal_config is None: + limits_per_plugin = self._disabled_limits_per_plugin + else: + config_limits_per_plugin = multimodal_config.limit_per_prompt + + extra_keys = config_limits_per_plugin.keys() - self._plugins.keys() + if extra_keys: + logger.warning( + "Detected extra keys in `--limit-mm-per-prompt` which " + "are not registered as multi-modal plugins: %s. " + "They will be ignored.", extra_keys) + + # NOTE: Currently the default is set to 1 for each plugin + # TODO: Automatically determine the limits based on budget + # once more models support multi-image inputs + limits_per_plugin = { + key: multimodal_config.get_limit_per_prompt(key) + for key in self._plugins + } + + self._limits_by_model[model_config] = limits_per_plugin + + def get_mm_limits_per_prompt( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: + """ + Get the maximum number of multi-modal input instances for each modality + that are allowed per prompt for a model class. 
+ + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + if self.has_processor(model_config): + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + return profiler.get_mm_limits() + + return self._limits_by_model[model_config] + + def register_processor( + self, + processor: MultiModalProcessorFactory[_I], + *, + info: ProcessingInfoFactory[_I], + dummy_inputs: DummyInputsBuilderFactory[_I], + ): + """ + Register a multi-modal processor to a model class. The processor + is constructed lazily, hence a factory method should be passed. + + When the model receives multi-modal data, the provided function is + invoked to transform the data into a dictionary of model inputs. + + See also: + :ref:`mm-processing` + """ + + def wrapper(model_cls: N) -> N: + if self._processor_factories.contains(model_cls, strict=True): + logger.warning( + "Model class %s already has a multi-modal processor " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._processor_factories[model_cls] = _ProcessorFactories( + info=info, + dummy_inputs=dummy_inputs, + processor=processor, + ) + + return model_cls + + return wrapper + + def _get_model_cls(self, model_config: "ModelConfig"): + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + return model_cls + + def has_processor(self, model_config: "ModelConfig") -> bool: + """ + Test whether a multi-modal processor is defined for a specific model. + + See also: + :ref:`mm-processing` + """ + return self._get_model_cls(model_config) in self._processor_factories + + def create_processor( + self, + model_config: "ModelConfig", + *, + tokenizer: Optional[AnyTokenizer] = None, + disable_cache: Optional[bool] = None, + ) -> BaseMultiModalProcessor[BaseProcessingInfo]: + """ + Create a multi-modal processor for a specific model and tokenizer. + + See also: + :ref:`mm-processing` + """ + if tokenizer is None: + tokenizer = cached_tokenizer_from_config(model_config) + if disable_cache is None: + disable_cache = model_config.disable_mm_preprocessor_cache + + model_cls = self._get_model_cls(model_config) + factories = self._processor_factories[model_cls] + + ctx = InputProcessingContext(model_config, tokenizer) + cache = None if disable_cache else self._processing_cache + + return factories.build_processor(ctx, cache=cache) + + def get_decoder_dummy_data( + self, + model_config: "ModelConfig", + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> DummyDecoderData: + """ + Create dummy data for profiling the memory usage of a model. + + The model is identified by ``model_config``. + """ + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) + + # Having more tokens is over-conservative but otherwise fine + token_ids = dummy_data.prompt_token_ids + if len(token_ids) < seq_len: + raise AssertionError( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(token_ids)} tokens instead.") + + return dummy_data + + def get_encoder_dummy_data( + self, + model_config: "ModelConfig", + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> DummyEncoderData: + """ + Create dummy data for profiling the memory usage of a model. + + The model is identified by ``model_config``. 
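+
+        Unlike :meth:`get_decoder_dummy_data`, receiving fewer than
+        ``seq_len`` dummy tokens here only logs a warning instead of
+        raising an error.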
+ """ + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) + + # Having more tokens is over-conservative but otherwise fine + token_ids = dummy_data.prompt_token_ids + if len(token_ids) < seq_len: + logger.warning_once( + f"Expected at least {seq_len} dummy encoder tokens for " + f"profiling, but found {len(token_ids)} tokens instead.") + + return dummy_data diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py new file mode 100644 index 0000000..fc0fb89 --- /dev/null +++ b/vllm/multimodal/utils.py @@ -0,0 +1,386 @@ +# SPDX-License-Identifier: Apache-2.0 + +from itertools import groupby +from pathlib import Path +from typing import TYPE_CHECKING, Optional, TypeVar, Union +from urllib.parse import ParseResult, urlparse + +import numpy as np +import numpy.typing as npt +import torch +from PIL import Image + +import vllm.envs as envs +from vllm.connections import HTTPConnection, global_http_connection + +from .audio import AudioMediaIO +from .base import MediaIO +from .image import ImageEmbeddingMediaIO, ImageMediaIO +from .inputs import PlaceholderRange +from .video import VideoMediaIO + +_M = TypeVar("_M") + +if TYPE_CHECKING: + from .hasher import MultiModalHashDict + from .inputs import MultiModalKwargs, MultiModalPlaceholderDict + + +class MediaConnector: + + def __init__( + self, + connection: HTTPConnection = global_http_connection, + *, + allowed_local_media_path: str = "", + ) -> None: + super().__init__() + + self.connection = connection + + if allowed_local_media_path: + allowed_local_media_path_ = Path(allowed_local_media_path) + + if not allowed_local_media_path_.exists(): + raise ValueError( + "Invalid `--allowed-local-media-path`: The path " + f"{allowed_local_media_path_} does not exist.") + if not allowed_local_media_path_.is_dir(): + raise ValueError( + "Invalid `--allowed-local-media-path`: The path " + f"{allowed_local_media_path_} must be a directory.") + else: + allowed_local_media_path_ = None + + self.allowed_local_media_path = allowed_local_media_path_ + + def _load_data_url( + self, + url_spec: ParseResult, + media_io: MediaIO[_M], + ) -> _M: + data_spec, data = url_spec.path.split(",", 1) + media_type, data_type = data_spec.split(";", 1) + + if data_type != "base64": + msg = "Only base64 data URLs are supported for now." 
+ raise NotImplementedError(msg) + + return media_io.load_base64(media_type, data) + + def _load_file_url( + self, + url_spec: ParseResult, + media_io: MediaIO[_M], + ) -> _M: + allowed_local_media_path = self.allowed_local_media_path + if allowed_local_media_path is None: + raise RuntimeError("Cannot load local files without " + "`--allowed-local-media-path`.") + + filepath = Path(url_spec.path) + if allowed_local_media_path not in filepath.resolve().parents: + raise ValueError( + f"The file path {filepath} must be a subpath " + f"of `--allowed-local-media-path` {allowed_local_media_path}.") + + return media_io.load_file(filepath) + + def load_from_url( + self, + url: str, + media_io: MediaIO[_M], + *, + fetch_timeout: Optional[int] = None, + ) -> _M: + url_spec = urlparse(url) + + if url_spec.scheme.startswith("http"): + connection = self.connection + data = connection.get_bytes(url, timeout=fetch_timeout) + + return media_io.load_bytes(data) + + if url_spec.scheme == "data": + return self._load_data_url(url_spec, media_io) + + if url_spec.scheme == "file": + return self._load_file_url(url_spec, media_io) + + msg = "The URL must be either a HTTP, data or file URL." + raise ValueError(msg) + + async def load_from_url_async( + self, + url: str, + media_io: MediaIO[_M], + *, + fetch_timeout: Optional[int] = None, + ) -> _M: + url_spec = urlparse(url) + + if url_spec.scheme.startswith("http"): + connection = self.connection + data = await connection.async_get_bytes(url, timeout=fetch_timeout) + + return media_io.load_bytes(data) + + if url_spec.scheme == "data": + return self._load_data_url(url_spec, media_io) + + if url_spec.scheme == "file": + return self._load_file_url(url_spec, media_io) + + msg = "The URL must be either a HTTP, data or file URL." + raise ValueError(msg) + + def fetch_audio( + self, + audio_url: str, + ) -> tuple[np.ndarray, Union[int, float]]: + """ + Load audio from a URL. + """ + audio_io = AudioMediaIO() + + return self.load_from_url( + audio_url, + audio_io, + fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, + ) + + async def fetch_audio_async( + self, + audio_url: str, + ) -> tuple[np.ndarray, Union[int, float]]: + """ + Asynchronously fetch audio from a URL. + """ + audio_io = AudioMediaIO() + + return await self.load_from_url_async( + audio_url, + audio_io, + fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, + ) + + def fetch_image( + self, + image_url: str, + *, + image_mode: str = "RGB", + ) -> Image.Image: + """ + Load a PIL image from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + + return self.load_from_url( + image_url, + image_io, + fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) + + async def fetch_image_async( + self, + image_url: str, + *, + image_mode: str = "RGB", + ) -> Image.Image: + """ + Asynchronously load a PIL image from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + + return await self.load_from_url_async( + image_url, + image_io, + fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) + + def fetch_video( + self, + video_url: str, + *, + image_mode: str = "RGB", + num_frames: int = 32, + ) -> npt.NDArray: + """ + Load video from a HTTP or base64 data URL. 
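+
+        By default, the frames are decoded in RGB mode and at most
+        ``num_frames`` frames are sampled uniformly from the video.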
+ """ + image_io = ImageMediaIO(image_mode=image_mode) + video_io = VideoMediaIO(image_io, num_frames=num_frames) + + return self.load_from_url( + video_url, + video_io, + fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, + ) + + async def fetch_video_async( + self, + video_url: str, + *, + image_mode: str = "RGB", + num_frames: int = 32, + ) -> npt.NDArray: + """ + Asynchronously load video from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + video_io = VideoMediaIO(image_io, num_frames=num_frames) + + return await self.load_from_url_async( + video_url, + video_io, + fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, + ) + + def fetch_image_embedding( + self, + data: str, + ) -> torch.Tensor: + """ + Load image embedding from a URL. + """ + image_embedding_io = ImageEmbeddingMediaIO() + + return image_embedding_io.load_base64("", data) + + +global_media_connector = MediaConnector() +"""The global :class:`MediaConnector` instance used by vLLM.""" + +fetch_audio = global_media_connector.fetch_audio +fetch_image = global_media_connector.fetch_image +fetch_video = global_media_connector.fetch_video + + +def encode_audio_base64( + audio: np.ndarray, + sampling_rate: int, +) -> str: + """Encode audio as base64.""" + audio_io = AudioMediaIO() + return audio_io.encode_base64((audio, sampling_rate)) + + +def encode_image_base64( + image: Image.Image, + *, + image_mode: str = "RGB", + format: str = "JPEG", +) -> str: + """ + Encode a pillow image to base64 format. + + By default, the image is converted into RGB format before being encoded. + """ + image_io = ImageMediaIO(image_mode=image_mode) + return image_io.encode_base64(image, image_format=format) + + +def encode_video_base64(frames: npt.NDArray) -> str: + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + return video_io.encode_base64(frames) + + +def merge_and_sort_multimodal_metadata( + mm_positions: "MultiModalPlaceholderDict", + mm_hashes: Optional["MultiModalHashDict"], +) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]: + """Given a MultiModalPlaceholderDict, merge all PlaceholderRange + objects from all available modalities into a single list of + PlaceholderRange, sorted by their offset (starting index in the input + sequence) in the ascending order. + + Optionally if a MultiModalHashDict is given, same operation will be + applied to the object and the sorted list of hashes will be returned. + + Returns: + list[str]: List of item modalities in order of their positions in + the input sequence. + list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from + mm_positions. + Optional[list[str]]: Sorted list of all hashes from mm_hashes if + given, None otherwise. + """ + + modalities = list(mm_positions.keys()) + + assert len(modalities) > 0, "No modalities found in the mm_positions." + + # For single modality, placeholder ranges and hashes are already sorted + # so we can return the list directly. 
+ if len(modalities) == 1: + modality = modalities[0] + placeholder_list = list(mm_positions[modality]) + + return [modality] * len( + placeholder_list + ), placeholder_list, None if not mm_hashes else mm_hashes[modality] + + # Create a list of (modality, placeholder, hash) tuples for all placeholders + all_items = [] + for modality in modalities: + placeholder_list = list(mm_positions[modality]) + hash_list: list[Optional[str]] = list( + mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [ + None + ] * len(placeholder_list) + + for placeholder, hash_value in zip(placeholder_list, hash_list): + all_items.append((modality, placeholder, hash_value)) + + # Sort all items by offset + all_items.sort(key=lambda x: x[1]['offset']) + + # Split into separate lists + sorted_modalities = [item[0] for item in all_items] + merged_placeholders = [item[1] for item in all_items] + merged_hashes = [str(item[2]) + for item in all_items] if mm_hashes is not None else None + + return sorted_modalities, merged_placeholders, merged_hashes + + +def group_mm_inputs_by_modality( + mm_inputs: list["MultiModalKwargs"]) -> list[list["MultiModalKwargs"]]: + """Group consecutive MultiModalKwargs from mm_inputs with the same modality + together into the same list for batching purpose. For MultiModalKwargs with + multiple modalities, put them into their own list. + + Args: + mm_inputs: List of MultiModalKwargs. + + Returns: + list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each + inner list contains consecutive MultiModalKwargs with same modality. + """ + if not mm_inputs: + return [] + + def modality_group_func(mm_input: "MultiModalKwargs") -> Union[str, int]: + # If the input has multiple modalities, return a id as the unique key + # for the mm_input input. + if len(mm_input.modalities) > 1: + return id(mm_input) + + elif len(mm_input.modalities) == 1: + return list(mm_input.modalities)[0] + + # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty, + # this is used to make InternVL with legacy pipeline still work with v1. 
+ else: + return "" + + return [ + list(group) for _, group in groupby(mm_inputs, key=modality_group_func) + ] diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py new file mode 100644 index 0000000..f7c3f10 --- /dev/null +++ b/vllm/multimodal/video.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 + +import base64 +from functools import partial +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +import numpy.typing as npt +from PIL import Image + +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_get_video_processor +from vllm.utils import is_list_of + +from .base import MediaIO, ModalityData +from .image import ImageMediaIO, ImagePlugin +from .inputs import MultiModalKwargs, VideoItem + +if TYPE_CHECKING: + from vllm.config import ModelConfig + +logger = init_logger(__name__) + + +class VideoPlugin(ImagePlugin): + """Plugin for video data.""" + + def get_data_key(self) -> str: + return "video" + + def _get_hf_video_processor( + self, + model_config: "ModelConfig", + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return cached_get_video_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) + + def _default_input_mapper( + self, + ctx: InputContext, + data: ModalityData[VideoItem], + **mm_processor_kwargs, + ) -> MultiModalKwargs: + model_config = ctx.model_config + + if isinstance(data, list) and len(data) == 1: + data = data[0] # type: ignore + + if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): + video_processor = self._get_hf_video_processor( + model_config, + mm_processor_kwargs, + ) + if video_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the video object") + try: + # NOTE: Similar to image; it may be a good idea to filter and + # pass mm_processor_kwargs here too, but for now we don't to + # avoid extra complexity if the initializer and preprocess + # signatures of the processor don't align + batch_data = video_processor(data, return_tensors="pt").data + except Exception: + logger.error("Failed to process video (%s)", data) + raise + + return MultiModalKwargs(batch_data) + + raise TypeError(f"Invalid video type: {type(data)}") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + return 4096 + + +def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: + num_frames, _, _, channels = frames.shape + new_height, new_width = size + resized_frames = np.empty((num_frames, new_height, new_width, channels), + dtype=frames.dtype) + # lazy import cv2 to avoid bothering users who only use text models + import cv2 + for i, frame in enumerate(frames): + resized_frame = cv2.resize(frame, (new_width, new_height)) + resized_frames[i] = resized_frame + return resized_frames + + +def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray: + _, height, width, _ = frames.shape + new_height = int(height * size_factor) + new_width = int(width * size_factor) + + return resize_video(frames, (new_height, new_width)) + + +def sample_frames_from_video(frames: npt.NDArray, + num_frames: int) -> npt.NDArray: + total_frames = frames.shape[0] + if num_frames == -1: + return frames + + frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + sampled_frames = 
frames[frame_indices, ...] + return sampled_frames + + +class VideoLoader: + + @classmethod + def load_bytes(self, data: bytes, num_frames: int = -1) -> npt.NDArray: + raise NotImplementedError + + +class OpenCVVideoBackend(VideoLoader): + + def get_cv2_video_api(self): + import cv2.videoio_registry as vr + + api_pref = None + for backend in vr.getStreamBufferedBackends(): + if not vr.hasBackend(backend): + continue + if not vr.isBackendBuiltIn(backend): + _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend) + if (abi < 1 or (abi == 1 and api < 2)): + continue + api_pref = backend + break + return api_pref + + @classmethod + def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: + import cv2 + + backend = cls().get_cv2_video_api() + cap = cv2.VideoCapture(BytesIO(data), backend, []) + if not cap.isOpened(): + raise ValueError("Could not open video stream") + + total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + full_read = num_frames == -1 or total_frames_num < num_frames + if full_read: + frame_idx = list(range(0, total_frames_num)) + else: + uniform_sampled_frames = np.linspace(0, + total_frames_num - 1, + num_frames, + dtype=int) + frame_idx = uniform_sampled_frames.tolist() + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8) + + i = 0 + for idx in range(total_frames_num): + ok = cap.grab() # next img + if not ok: + break + if idx in frame_idx: # only decompress needed + ret, frame = cap.retrieve() + if ret: + frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + i += 1 + # we expect all frames loaded + assert i == num_frames + return frames + + +class VideoMediaIO(MediaIO[npt.NDArray]): + + def __init__( + self, + image_io: ImageMediaIO, + *, + num_frames: int = 32, + ) -> None: + super().__init__() + + self.image_io = image_io + self.num_frames = num_frames + self.video_loader = OpenCVVideoBackend + + def load_bytes(self, data: bytes) -> npt.NDArray: + return self.video_loader.load_bytes(data, self.num_frames) + + def load_base64(self, media_type: str, data: str) -> npt.NDArray: + if media_type.lower() == "video/jpeg": + load_frame = partial( + self.image_io.load_base64, + "image/jpeg", + ) + + return np.stack([ + np.array(load_frame(frame_data)) + for frame_data in data.split(",") + ]) + + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> npt.NDArray: + with filepath.open("rb") as f: + data = f.read() + + return self.load_bytes(data) + + def encode_base64( + self, + media: npt.NDArray, + *, + video_format: str = "JPEG", + ) -> str: + video = media + + if video_format == "JPEG": + encode_frame = partial( + self.image_io.encode_base64, + image_format=video_format, + ) + + return ",".join( + encode_frame(Image.fromarray(frame)) for frame in video) + + msg = "Only JPEG format is supported for now." 
+ raise NotImplementedError(msg) diff --git a/vllm/outputs.py b/vllm/outputs.py new file mode 100644 index 0000000..014e8d5 --- /dev/null +++ b/vllm/outputs.py @@ -0,0 +1,508 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from collections.abc import MutableSequence +from collections.abc import Sequence as GenericSequence +from dataclasses import dataclass +from typing import Generic, Optional, Union + +import torch +from typing_extensions import TypeVar, deprecated + +from vllm.lora.request import LoRARequest +from vllm.multimodal.inputs import MultiModalPlaceholderDict +from vllm.sampling_params import RequestOutputKind +from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, + SequenceGroup, SequenceGroupBase, SequenceStatus) + + +@dataclass +class CompletionOutput: + """The output data of one completion output of a request. + + Args: + index: The index of the output in the request. + text: The generated output text. + token_ids: The token IDs of the generated output text. + cumulative_logprob: The cumulative log probability of the generated + output text. + logprobs: The log probabilities of the top probability words at each + position if the logprobs are requested. + finish_reason: The reason why the sequence is finished. + stop_reason: The stop string or token id that caused the completion + to stop, None if the completion finished for some other reason + including encountering the EOS token. + lora_request: The LoRA request that was used to generate the output. + """ + + index: int + text: str + token_ids: GenericSequence[int] + cumulative_logprob: Optional[float] + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None + + def finished(self) -> bool: + return self.finish_reason is not None + + def __repr__(self) -> str: + return (f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"stop_reason={self.stop_reason})") + + +@dataclass +class PoolingOutput: + """The output data of one pooling output of a request. + + Args: + data: The extracted hidden states. + """ + data: torch.Tensor + + def __repr__(self) -> str: + return (f"PoolingOutput(data={self.data})") + + def __eq__(self, other: object) -> bool: + return (isinstance(other, self.__class__) and bool( + (self.data == other.data).all())) + + @property + @deprecated("`LLM.encode()` now stores raw outputs in the `data` " + "attribute. To return embeddings, use `LLM.embed()`. " + "To return class probabilities, use `LLM.classify()` " + "and access the `probs` attribute. ") + def embedding(self) -> list[float]: + return self.data.tolist() + + +class RequestOutput: + """The output data of a completion request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + For encoder/decoder models, this is the + decoder input prompt. + prompt_token_ids: The token IDs of the prompt. + For encoder/decoder models, this is the + decoder input prompt token ids. + prompt_logprobs: The log probabilities to return per prompt token. + outputs: The output sequences of the request. + finished: Whether the whole request is finished. + metrics: Metrics associated with the request. + lora_request: The LoRA request that was used to generate the output. 
+ encoder_prompt: The encoder prompt string of the request. + None if decoder-only. + encoder_prompt_token_ids: The token IDs of the encoder prompt. + None if decoder-only. + num_cached_tokens: The number of tokens with prefix cache hit. + """ + + def __init__( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[list[int]], + prompt_logprobs: Optional[PromptLogprobs], + outputs: list[CompletionOutput], + finished: bool, + metrics: Optional[RequestMetrics] = None, + lora_request: Optional[LoRARequest] = None, + encoder_prompt: Optional[str] = None, + encoder_prompt_token_ids: Optional[list[int]] = None, + num_cached_tokens: Optional[int] = None, + *, + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.multi_modal_placeholders = multi_modal_placeholders or {} + self.prompt_logprobs = prompt_logprobs + self.outputs = outputs + self.finished = finished + self.metrics = metrics + self.lora_request = lora_request + self.encoder_prompt = encoder_prompt + self.encoder_prompt_token_ids = encoder_prompt_token_ids + self.num_cached_tokens = num_cached_tokens + + def add(self, next_output: "RequestOutput") -> None: + """Merge subsequent RequestOutput into this one""" + + self.finished |= next_output.finished + + for next_completion in next_output.outputs: + for completion in self.outputs: + if completion.index == next_completion.index: + # Merge outputs with same index + completion.text += next_completion.text + if not isinstance(completion.token_ids, MutableSequence): + completion.token_ids = list(completion.token_ids) + completion.token_ids.extend(next_completion.token_ids) + if next_completion.logprobs: + assert completion.logprobs is not None + completion.logprobs.extend(next_completion.logprobs) + completion.cumulative_logprob = ( + next_completion.cumulative_logprob) + completion.finish_reason = next_completion.finish_reason + completion.stop_reason = next_completion.stop_reason + break + else: + self.outputs.append(next_completion) + + @classmethod + def from_seq_group( + cls, seq_group: SequenceGroup, use_cache: bool, + seq_id_to_seq_group: dict[str, SequenceGroupBase] + ) -> Optional["RequestOutput"]: + finished = seq_group.is_finished() + + if seq_group.request_id in seq_id_to_seq_group: + group: SequenceGroupBase = seq_id_to_seq_group[ + seq_group.request_id] + assembled_seq_group = group.maybe_assemble_group(seq_group) + if finished: + group.finish_seq(seq_group) + if assembled_seq_group is None: + return None + return cls.from_seq_group(assembled_seq_group, use_cache, + seq_id_to_seq_group) + + sampling_params = seq_group.sampling_params + if sampling_params is None: + raise ValueError( + "Sampling parameters are missing for a CompletionRequest.") + + if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( + not finished): + return None + + # Init cache (if needed) + if use_cache and seq_group.cached_request_output is None: + seq_group.cached_request_output = RequestOutput( # type: ignore + request_id="", + prompt=None, + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[], + finished=False) + + top_n_seqs = seq_group.get_seqs() + + # Create the outputs. + # NOTE: We need omit logprobs here explicitly because the sequence + # always has the logprobs of the sampled tokens even if the + # logprobs are not requested. 
+ include_logprobs = sampling_params.logprobs is not None + text_buffer_length = sampling_params.output_text_buffer_length + delta = sampling_params.output_kind == RequestOutputKind.DELTA + + outputs = [] + include_prompt = True + # num_cached_tokens should be the same for all the sequences + num_cached_tokens = None + for i, seq in enumerate(top_n_seqs): + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) + + output_token_ids = seq.get_output_token_ids_to_return(delta) + num_output_tokens = 1 if isinstance(output_token_ids, + int) else len(output_token_ids) + num_cached_tokens = seq.data.get_num_cached_tokens() + + output_logprobs = seq.output_logprobs if include_logprobs else None + + if delta: + # Slice logprobs delta if applicable + if output_logprobs: + # num_output_tokens can be 0 when n > 1 and request finishes + # before the others + if num_output_tokens > 0: + output_logprobs = output_logprobs[-num_output_tokens:] + else: + output_logprobs = None + # Don't include prompt if this is after the first output + # containing decode token ids + if include_prompt and seq.get_output_len() > num_output_tokens: + include_prompt = False + + if use_cache: + # Get cached output object + cached_outputs = seq_group.cached_request_output.outputs # type: ignore + if i >= len(cached_outputs): + cached_outputs.append( + CompletionOutput(index=i, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + stop_reason=None)) + output = cached_outputs[i] + + # Init cached output object + assert output.index == i + output.text = output_text + + if isinstance(output_token_ids, int): + output.token_ids.clear() + output.token_ids.append(output_token_ids) + else: + output.token_ids = output_token_ids + + output.cumulative_logprob = seq.get_cumulative_logprob() \ + if include_logprobs else None + output.logprobs = output_logprobs + output.finish_reason = SequenceStatus.get_finished_reason( + seq.status) + output.stop_reason = seq.stop_reason + + else: + output = CompletionOutput( + top_n_seqs.index(seq), output_text, [output_token_ids] + if isinstance(output_token_ids, int) else output_token_ids, + seq.get_cumulative_logprob() if include_logprobs else None, + output_logprobs, + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason) + + outputs.append(output) + + # Every sequence in the sequence group should have the same prompt. 
+ if include_prompt: + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + else: + prompt = None + prompt_token_ids = None + encoder_prompt = None + encoder_prompt_token_ids = None + prompt_logprobs = None + finished_time = time.time() if finished else None + seq_group.set_finished_time(finished_time) + + init_kwargs = { + "request_id": seq_group.request_id, + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "prompt_logprobs": prompt_logprobs, + "outputs": outputs, + "finished": finished, + "metrics": seq_group.metrics, + "lora_request": seq_group.lora_request, + "encoder_prompt": encoder_prompt, + "encoder_prompt_token_ids": encoder_prompt_token_ids, + "num_cached_tokens": num_cached_tokens, + "multi_modal_placeholders": seq_group.multi_modal_placeholders + } + + if use_cache: + request_output = seq_group.cached_request_output + request_output.__init__(**init_kwargs) # type: ignore + else: + request_output = cls(**init_kwargs) # type: ignore + + return request_output + + def __repr__(self) -> str: + return (f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"encoder_prompt={self.encoder_prompt!r}, " + f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished}, " + f"metrics={self.metrics}, " + f"lora_request={self.lora_request}, " + f"num_cached_tokens={self.num_cached_tokens}, " + f"multi_modal_placeholders={self.multi_modal_placeholders})") + + +_O = TypeVar("_O", default=PoolingOutput) + + +class PoolingRequestOutput(Generic[_O]): + """ + The output data of a pooling request to the LLM. + + Args: + request_id (str): A unique identifier for the pooling request. + outputs (PoolingOutput): The pooling results for the given input. + prompt_token_ids (list[int]): A list of token IDs used in the prompt. + finished (bool): A flag indicating whether the pooling is completed. + """ + + def __init__(self, request_id: str, outputs: _O, + prompt_token_ids: list[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + @staticmethod + def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput": + pooled_data = seq_group.pooled_data + assert pooled_data is not None + + data = pooled_data.to(dtype=torch.float32, device="cpu") + output = PoolingOutput(data) + prompt_token_ids = seq_group.prompt_token_ids + finished = seq_group.is_finished() + + return PoolingRequestOutput(seq_group.request_id, output, + prompt_token_ids, finished) + + def __repr__(self): + """ + Returns a string representation of an PoolingRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the pooling request's results. + + Returns: + str: A string representation of the PoolingRequestOutput instance. 
+ """ + return (f"{type(self).__name__}(request_id={self.request_id!r}, " + f"outputs={self.outputs!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") + + +class RequestOutputFactory: + + @staticmethod + def create(seq_group: SequenceGroup, + seq_id_to_seq_group: dict[str, SequenceGroupBase], + use_cache: bool = False): + if seq_group.pooled_data is not None: + return PoolingRequestOutput.from_seq_group(seq_group) + else: + return RequestOutput.from_seq_group(seq_group, use_cache, + seq_id_to_seq_group) + + +@dataclass +class EmbeddingOutput: + """The output data of one embedding output of a request. + + Args: + embedding: The embedding vector, which is a list of floats. + Its length depends on the hidden dimension of the model. + """ + embedding: list[float] + + @staticmethod + def from_base(pooling_output: PoolingOutput): + pooled_data = pooling_output.data + if pooled_data.ndim != 1: + raise ValueError("pooled_data should be a 1-D embedding vector") + + return EmbeddingOutput(pooled_data.tolist()) + + @property + def hidden_size(self) -> int: + return len(self.embedding) + + def __repr__(self) -> str: + return f"EmbeddingOutput(hidden_size={self.hidden_size})" + + +class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]): + + @staticmethod + def from_base(request_output: PoolingRequestOutput): + return EmbeddingRequestOutput( + request_id=request_output.request_id, + outputs=EmbeddingOutput.from_base(request_output.outputs), + prompt_token_ids=request_output.prompt_token_ids, + finished=request_output.finished, + ) + + +@dataclass +class ClassificationOutput: + """The output data of one classification output of a request. + + Args: + probs: The probability vector, which is a list of floats. + Its length depends on the number of classes. + """ + probs: list[float] + + @staticmethod + def from_base(pooling_output: PoolingOutput): + pooled_data = pooling_output.data + if pooled_data.ndim != 1: + raise ValueError("pooled_data should be a 1-D probability vector") + + return ClassificationOutput(pooled_data.tolist()) + + @property + def num_classes(self) -> int: + return len(self.probs) + + def __repr__(self) -> str: + return f"ClassificationOutput(num_classes={self.num_classes})" + + +class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]): + + @staticmethod + def from_base(request_output: PoolingRequestOutput): + return ClassificationRequestOutput( + request_id=request_output.request_id, + outputs=ClassificationOutput.from_base(request_output.outputs), + prompt_token_ids=request_output.prompt_token_ids, + finished=request_output.finished, + ) + + +@dataclass +class ScoringOutput: + """The output data of one scoring output of a request. + + Args: + score: The similarity score, which is a scalar value. + """ + score: float + + @staticmethod + def from_base(pooling_output: PoolingOutput): + pooled_data = pooling_output.data + if pooled_data.ndim != 0: + raise ValueError("pooled_data should be a scalar score") + + return ScoringOutput(pooled_data.item()) + + def __repr__(self) -> str: + return f"ScoringOutput(score={self.score})" + + @property + @deprecated("`LLM.score()` now returns scalar scores. " + "Please access it via the `score` attribute. 
") + def embedding(self) -> list[float]: + return [self.score] + + +class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]): + + @staticmethod + def from_base(request_output: PoolingRequestOutput): + return ScoringRequestOutput( + request_id=request_output.request_id, + outputs=ScoringOutput.from_base(request_output.outputs), + prompt_token_ids=request_output.prompt_token_ids, + finished=request_output.finished, + ) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py new file mode 100644 index 0000000..3722344 --- /dev/null +++ b/vllm/platforms/__init__.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import traceback +from itertools import chain +from typing import TYPE_CHECKING, Optional + +from vllm.plugins import load_plugins_by_group +from vllm.utils import resolve_obj_by_qualname + +from .interface import _Backend # noqa: F401 +from .interface import CpuArchEnum, Platform, PlatformEnum + +logger = logging.getLogger(__name__) + + +def vllm_version_matches_substr(substr: str) -> bool: + """ + Check to see if the vLLM version matches a substring. + """ + from importlib.metadata import PackageNotFoundError, version + try: + vllm_version = version("vllm") + except PackageNotFoundError as e: + logger.warning( + "The vLLM package was not found, so its version could not be " + "inspected. This may cause platform detection to fail.") + raise e + return substr in vllm_version + + +def tpu_platform_plugin() -> Optional[str]: + is_tpu = False + logger.debug("Checking if TPU platform is available.") + try: + # While it's technically possible to install libtpu on a + # non-TPU machine, this is a very uncommon scenario. Therefore, + # we assume that libtpu is installed if and only if the machine + # has TPUs. + import libtpu # noqa: F401 + is_tpu = True + logger.debug("Confirmed TPU platform is available.") + except Exception as e: + logger.debug("TPU platform is not available because: %s", str(e)) + pass + + return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None + + +def cuda_platform_plugin() -> Optional[str]: + return "vllm.platforms.cuda.CudaPlatform" + is_cuda = False + logger.debug("Checking if CUDA platform is available.") + try: + from vllm.utils import import_pynvml + pynvml = import_pynvml() + pynvml.nvmlInit() + try: + # NOTE: Edge case: vllm cpu build on a GPU machine. + # Third-party pynvml can be imported in cpu build, + # we need to check if vllm is built with cpu too. + # Otherwise, vllm will always activate cuda plugin + # on a GPU machine, even if in a cpu build. + is_cuda = (pynvml.nvmlDeviceGetCount() > 0 + and not vllm_version_matches_substr("cpu")) + if pynvml.nvmlDeviceGetCount() <= 0: + logger.debug( + "CUDA platform is not available because no GPU is found.") + if vllm_version_matches_substr("cpu"): + logger.debug("CUDA platform is not available because" + " vLLM is built with CPU.") + if is_cuda: + logger.debug("Confirmed CUDA platform is available.") + finally: + pynvml.nvmlShutdown() + except Exception as e: + logger.debug("Exception happens when checking CUDA platform: %s", + str(e)) + if "nvml" not in e.__class__.__name__.lower(): + # If the error is not related to NVML, re-raise it. + raise e + + # CUDA is supported on Jetson, but NVML may not be. 
+ import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") \ + or os.path.exists("/sys/class/tegra-firmware") + + if cuda_is_jetson(): + logger.debug("Confirmed CUDA platform is available on Jetson.") + is_cuda = True + else: + logger.debug("CUDA platform is not available because: %s", str(e)) + + return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None + + +def rocm_platform_plugin() -> Optional[str]: + is_rocm = False + logger.debug("Checking if ROCm platform is available.") + try: + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + logger.debug("Confirmed ROCm platform is available.") + else: + logger.debug("ROCm platform is not available because" + " no GPU is found.") + finally: + amdsmi.amdsmi_shut_down() + except Exception as e: + logger.debug("ROCm platform is not available because: %s", str(e)) + pass + + return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None + + +def hpu_platform_plugin() -> Optional[str]: + is_hpu = False + logger.debug("Checking if HPU platform is available.") + try: + from importlib import util + is_hpu = util.find_spec('habana_frameworks') is not None + if is_hpu: + logger.debug("Confirmed HPU platform is available.") + else: + logger.debug("HPU platform is not available because " + "habana_frameworks is not found.") + except Exception as e: + logger.debug("HPU platform is not available because: %s", str(e)) + pass + + return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None + + +def xpu_platform_plugin() -> Optional[str]: + is_xpu = False + logger.debug("Checking if XPU platform is available.") + try: + # installed IPEX if the machine has XPUs. + import intel_extension_for_pytorch # noqa: F401 + import oneccl_bindings_for_pytorch # noqa: F401 + import torch + if hasattr(torch, 'xpu') and torch.xpu.is_available(): + is_xpu = True + logger.debug("Confirmed XPU platform is available.") + except Exception as e: + logger.debug("XPU platform is not available because: %s", str(e)) + pass + + return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None + + +def cpu_platform_plugin() -> Optional[str]: + is_cpu = False + logger.debug("Checking if CPU platform is available.") + try: + is_cpu = vllm_version_matches_substr("cpu") + if is_cpu: + logger.debug("Confirmed CPU platform is available because" + " vLLM is built with CPU.") + if not is_cpu: + import sys + is_cpu = sys.platform.startswith("darwin") + if is_cpu: + logger.debug("Confirmed CPU platform is available" + " because the machine is MacOS.") + + except Exception as e: + logger.debug("CPU platform is not available because: %s", str(e)) + pass + + return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None + + +def neuron_platform_plugin() -> Optional[str]: + is_neuron = False + logger.debug("Checking if Neuron platform is available.") + try: + import transformers_neuronx # noqa: F401 + is_neuron = True + logger.debug("Confirmed Neuron platform is available because" + " transformers_neuronx is found.") + except ImportError as e: + logger.debug("Neuron platform is not available because: %s", str(e)) + pass + + return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None + + +builtin_platform_plugins = { + 'tpu': tpu_platform_plugin, + 'cuda': cuda_platform_plugin, + 'rocm': rocm_platform_plugin, + 'hpu': hpu_platform_plugin, + 'xpu': xpu_platform_plugin, + 'cpu': cpu_platform_plugin, + 'neuron': neuron_platform_plugin, +} + + +def resolve_current_platform_cls_qualname() -> str: + 
platform_plugins = load_plugins_by_group('vllm.platform_plugins') + + activated_plugins = [] + + for name, func in chain(builtin_platform_plugins.items(), + platform_plugins.items()): + try: + assert callable(func) + platform_cls_qualname = func() + if platform_cls_qualname is not None: + activated_plugins.append(name) + except Exception: + pass + + activated_builtin_plugins = list( + set(activated_plugins) & set(builtin_platform_plugins.keys())) + activated_oot_plugins = list( + set(activated_plugins) & set(platform_plugins.keys())) + + if len(activated_oot_plugins) >= 2: + raise RuntimeError( + "Only one platform plugin can be activated, but got: " + f"{activated_oot_plugins}") + elif len(activated_oot_plugins) == 1: + platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]() + logger.info("Platform plugin %s is activated", + activated_oot_plugins[0]) + elif len(activated_builtin_plugins) >= 2: + raise RuntimeError( + "Only one platform plugin can be activated, but got: " + f"{activated_builtin_plugins}") + elif len(activated_builtin_plugins) == 1: + platform_cls_qualname = builtin_platform_plugins[ + activated_builtin_plugins[0]]() + logger.info("Automatically detected platform %s.", + activated_builtin_plugins[0]) + else: + platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform" + logger.info( + "No platform detected, vLLM is running on UnspecifiedPlatform") + return platform_cls_qualname + + +_current_platform = None +_init_trace: str = '' + +if TYPE_CHECKING: + current_platform: Platform + + +def __getattr__(name: str): + if name == 'current_platform': + # lazy init current_platform. + # 1. out-of-tree platform plugins need `from vllm.platforms import + # Platform` so that they can inherit `Platform` class. Therefore, + # we cannot resolve `current_platform` during the import of + # `vllm.platforms`. + # 2. when users use out-of-tree platform plugins, they might run + # `import vllm`, some vllm internal code might access + # `current_platform` during the import, and we need to make sure + # `current_platform` is only resolved after the plugins are loaded + # (we have tests for this, if any developer violate this, they will + # see the test failures). 
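+        #
+        # A first access such as `from vllm.platforms import current_platform`
+        # lands here and performs the actual resolution.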
+ global _current_platform + if _current_platform is None: + platform_cls_qualname = resolve_current_platform_cls_qualname() + _current_platform = resolve_obj_by_qualname( + platform_cls_qualname)() + global _init_trace + _init_trace = "".join(traceback.format_stack()) + return _current_platform + elif name in globals(): + return globals()[name] + else: + raise AttributeError( + f"No attribute named '{name}' exists in {__name__}.") + + +__all__ = [ + 'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum', + "_init_trace" +] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py new file mode 100644 index 0000000..67466bd --- /dev/null +++ b/vllm/platforms/cpu.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +from importlib.util import find_spec +from typing import TYPE_CHECKING, Optional + +import psutil +import torch + +from vllm.logger import init_logger + +from .interface import Platform, PlatformEnum, _Backend + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +logger = init_logger(__name__) + + +class CpuPlatform(Platform): + _enum = PlatformEnum.CPU + device_name: str = "cpu" + device_type: str = "cpu" + dispatch_key: str = "CPU" + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return "cpu" + + @classmethod + def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + dtype: torch.dtype, kv_cache_dtype: Optional[str], + block_size: int, use_v1: bool, + use_mla: bool) -> str: + if selected_backend and selected_backend != _Backend.TORCH_SDPA: + logger.info("Cannot use %s backend on CPU.", selected_backend) + if use_mla: + logger.info("Using CPU MLA backend.") + return "vllm.attention.backends.cpu_mla.CPUMLABackend" + logger.info("Using Torch SDPA backend.") + return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + return psutil.virtual_memory().total + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + + @classmethod + def inference_mode(cls): + return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + import vllm.envs as envs + from vllm.utils import GiB_bytes + model_config = vllm_config.model_config + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + if not model_config.enforce_eager: + model_config.enforce_eager = True + + cache_config = vllm_config.cache_config + + ipex_avaliable = find_spec("intel_extension_for_pytorch") is not None + + if cache_config and cache_config.block_size is None: + cache_config.block_size = 128 if ipex_avaliable else 16 + + if not ipex_avaliable and cache_config.block_size != 16: + raise RuntimeError( + f"--block-size={cache_config.block_size} requires" + " intel_extension_for_pytorch") + + scheduler_config = vllm_config.scheduler_config + if ((scheduler_config.chunked_prefill_enabled + or cache_config.enable_prefix_caching) + and cache_config.cache_dtype != "auto"): + raise RuntimeError("Chunked-prefill and prefix-cache on the CPU " + "backend is not compatible with FP8 KV cache.") + + if cache_config.cache_dtype == "fp8_e4m3": + cache_config.cache_dtype = "fp8_e5m2" + logger.warning( + "CPU backend doesn't support fp8_e4m3 KV cache type, " + "cast to fp8_e5m2.") + + if (cache_config.cache_dtype != "auto" + and model_config.dtype == 
torch.half): + logger.warning("FP8 KV cache on the CPU backend only does not" + " support fp16 for now, cast to bf16.") + model_config.dtype = torch.bfloat16 + + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE + + if kv_cache_space >= 0: + if kv_cache_space == 0: + cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore + logger.warning( + "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " + "for CPU backend is not set, using 4 by default.") + else: + cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa + else: + raise RuntimeError( + "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + parallel_config = vllm_config.parallel_config + if (parallel_config.distributed_executor_backend is not None + and parallel_config.distributed_executor_backend != "mp"): + logger.warning(("%s is not supported on CPU, fallback to mp " + "distributed executor backend."), + parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend = "mp" + if parallel_config.worker_cls == "auto": + if vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.cpu_worker.CPUWorker" + else: + parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker" + + assert vllm_config.device_config.device_type == "cpu" + + # + # Environment variables for CPU executor + # + + # Set default threads num for OpenMP parallel + os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads()) + + # Disable torch async compiling which won't work with daemonic processes + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # Intel OpenMP setting + ld_prealod_str = os.getenv("LD_PRELOAD", "") + if "libiomp5.so" in ld_prealod_str: + # The time(milliseconds) that a thread should wait after + # completing the execution of a parallel region, before sleeping. + os.environ['KMP_BLOCKTIME'] = "1" + # Prevents the CPU to run into low performance state + os.environ['KMP_TPAUSE'] = "0" + # Provides fine granularity parallelism + os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist" + os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist" + os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" + + # To hint IPEX uses shared memory based AllReduce + os.environ["LOCAL_WORLD_SIZE"] = str( + vllm_config.parallel_config.tensor_parallel_size) + if sys.platform == "darwin" and \ + envs.VLLM_WORKER_MULTIPROC_METHOD == "fork": + if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD', None) is None: + logger.warning( + "Default to spawn method on MacOS. If this is not desired," + " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.") + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + + @classmethod + def is_pin_memory_available(cls) -> bool: + logger.warning("Pin memory is not supported on CPU.") + return False + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU" + + @classmethod + def get_device_communicator_cls(cls) -> str: + """ + Get device specific communicator class for distributed communication. 
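+
+        For the CPU backend this resolves to ``CpuCommunicator`` from
+        ``vllm.distributed.device_communicators.cpu_communicator``.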
+ """ + return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator" # noqa diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py new file mode 100644 index 0000000..171331e --- /dev/null +++ b/vllm/platforms/cuda.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Code inside this file can safely assume cuda platform, e.g. importing +pynvml. However, it should not initialize cuda context. +""" + +import os +from functools import wraps +from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar, + Union) + +import torch +from typing_extensions import ParamSpec + +# import custom ops, trigger op registration +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils import import_pynvml + +from .interface import DeviceCapability, Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig +else: + ModelConfig = None + VllmConfig = None + +logger = init_logger(__name__) + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +pynvml = import_pynvml() + +# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models +# see https://github.com/huggingface/diffusers/issues/9704 for details +# torch.backends.cuda.enable_cudnn_sdp(False) + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if device_ids == [""]: + msg = ( + "CUDA_VISIBLE_DEVICES is set to empty string, which means" + " GPU support is disabled. If you are using ray, please unset" + " the environment variable `CUDA_VISIBLE_DEVICES` inside the" + " worker/actor. " + "Check https://github.com/vllm-project/vllm/issues/8402 for" + " more information.") + raise RuntimeError(msg) + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + + +def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: + + @wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() + + return wrapper + + +class CudaPlatformBase(Platform): + _enum = PlatformEnum.CUDA + device_name: str = "cuda" + device_type: str = "cuda" + dispatch_key: str = "CUDA" + ray_device_key: str = "GPU" + device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + + @classmethod + def get_device_capability(cls, + device_id: int = 0 + ) -> Optional[DeviceCapability]: + raise NotImplementedError + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + + @classmethod + def is_fully_connected(cls, device_ids: List[int]) -> bool: + raise NotImplementedError + + @classmethod + def log_warnings(cls): + pass + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + compilation_config = vllm_config.compilation_config + model_config = vllm_config.model_config + + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. Please launch without " + "--num-scheduler-steps.") + else: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.gpu_worker.Worker" + else: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.gpu_worker.Worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" + + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + + # TODO(lucas): handle this more gracefully + # Note: model_config may be None during testing + if model_config is not None and model_config.use_mla: + # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then + # we default to FlashMLA backend, so we need to force the blocksize + # here + use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \ + or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") + from vllm.attention.ops.flashmla import is_flashmla_supported + if use_flashmla and is_flashmla_supported()[0] \ + and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLA backend.") + + if (parallel_config.data_parallel_size > 1 + and compilation_config.use_cudagraph): + logger.info( + "Data Parallel: Forcing enforce eager to be True since DP is " + "currently not supported with CUDA Graphs.") + vllm_config.model_config.enforce_eager = True + compilation_config.use_cudagraph = False + + @classmethod + def get_current_memory_usage(cls, + device: Optional[torch.types.Device] = None + ) -> float: + torch.cuda.reset_peak_memory_stats(device) + return torch.cuda.max_memory_allocated(device) + + @classmethod + def get_attn_backend_cls(cls, selected_backend, head_size, dtype, + kv_cache_dtype, block_size, use_v1, + use_mla) -> str: + if use_mla: + # TODO(lucas): refactor to be more concise + # we should probably consider factoring out V1 here + if selected_backend == _Backend.TRITON_MLA or block_size != 64: + if use_v1: + logger.info_once("Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." 
+ "triton_mla.TritonMLABackend") + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" + else: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + if not is_flashmla_supported()[0]: + logger.warning( + "FlashMLA backend is not supported due to %s", + is_flashmla_supported()[1]) + elif block_size != 64: + logger.warning( + "FlashMLA backend is not supported for block size %d" + " (currently only supports block size 64).", + block_size) + else: + if use_v1: + logger.info_once( + "Using FlashMLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashmla.FlashMLABackend") + else: + logger.info("Using FlashMLA backend.") + return ("vllm.attention.backends." + "flashmla.FlashMLABackend") + if use_v1: + if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: + logger.info_once("Using Triton backend on V1 engine.") + return ("vllm.v1.attention.backends." + "triton_attn.TritonAttentionBackend") + if cls.has_device_capability(80): + logger.info_once("Using Flash Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "flash_attn.FlashAttentionBackend") + if selected_backend == _Backend.FLASHINFER: + logger.info("Using FlashInfer backend.") + return "vllm.attention.backends.flashinfer.FlashInferBackend" + elif selected_backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + return "vllm.attention.backends.xformers.XFormersBackend" + elif selected_backend == _Backend.FLASH_ATTN: + pass + elif selected_backend: + raise ValueError( + f"Invalid attention backend for {cls.device_name}, " + f"with use_v1: {use_v1} use_mla: {use_mla}") + + target_backend = _Backend.FLASH_ATTN + if not cls.has_device_capability(80): + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + target_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + target_backend = _Backend.XFORMERS + elif block_size % 16 != 0: + logger.info( + "Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + target_backend = _Backend.XFORMERS + + # FlashAttn is valid for the model, checking if the package is + # installed. + if target_backend == _Backend.FLASH_ATTN: + try: + # import vllm.vllm_flash_attn # noqa: F401 + from vllm.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend, flash_attn_supports_fp8) + + supported_sizes = \ + FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention-2 backend for head size %d.", + head_size) + target_backend = _Backend.XFORMERS + fp8_kv_cache = (kv_cache_dtype is not None + and kv_cache_dtype.startswith("fp8")) + if (fp8_kv_cache and not flash_attn_supports_fp8()): + logger.info( + "Cannot use FlashAttention backend for FP8 KV cache.") + logger.warning( + "Please use FlashInfer backend with FP8 KV Cache for " + "better performance by setting environment variable " + "VLLM_ATTENTION_BACKEND=FLASHINFER") + target_backend = _Backend.XFORMERS + except ImportError: + logger.info( + "Cannot use FlashAttention-2 backend because the " + "vllm.vllm_flash_attn package is not found. 
" + "Make sure that vllm_flash_attn was built and installed " + "(on by default).") + target_backend = _Backend.XFORMERS + + if target_backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + return "vllm.attention.backends.xformers.XFormersBackend" + + logger.info("Using Flash Attention backend.") + return "vllm.attention.backends.flash_attn.FlashAttentionBackend" + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + + @classmethod + def supports_fp8(cls) -> bool: + return cls.has_device_capability(89) + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + return True + + @classmethod + def use_custom_allreduce(cls) -> bool: + return True + + +# NVML utils +# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. +# the major benefit of using NVML is that it will not initialize CUDA +class NvmlCudaPlatform(CudaPlatformBase): + + @classmethod + @with_nvml_context + def get_device_capability(cls, + device_id: int = 0 + ) -> Optional[DeviceCapability]: + try: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + return DeviceCapability(major=major, minor=minor) + except RuntimeError: + return None + + @classmethod + @with_nvml_context + def has_device_capability( + cls, + capability: Union[Tuple[int, int], int], + device_id: int = 0, + ) -> bool: + try: + return super().has_device_capability(capability, device_id) + except RuntimeError: + return False + + @classmethod + @with_nvml_context + def get_device_name(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + return cls._get_physical_device_name(physical_device_id) + + @classmethod + @with_nvml_context + def get_device_uuid(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + return pynvml.nvmlDeviceGetUUID(handle) + + @classmethod + @with_nvml_context + def get_device_total_memory(cls, device_id: int = 0) -> int: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) + + @classmethod + @with_nvml_context + def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [ + pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, + peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK, + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + logger.exception( + "NVLink detection failed. 
This is normal if"
+                            " your machine has no NVLink equipped.")
+                        return False
+        return True
+
+    @classmethod
+    def _get_physical_device_name(cls, device_id: int = 0) -> str:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        return pynvml.nvmlDeviceGetName(handle)
+
+    @classmethod
+    @with_nvml_context
+    def log_warnings(cls):
+        device_ids: int = pynvml.nvmlDeviceGetCount()
+        if device_ids > 1:
+            device_names = [
+                cls._get_physical_device_name(i) for i in range(device_ids)
+            ]
+            if (len(set(device_names)) > 1
+                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
+                logger.warning(
+                    "Detected different devices in the system: %s. Please"
+                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
+                    "avoid unexpected behavior.",
+                    ", ".join(device_names),
+                )
+
+
+class NonNvmlCudaPlatform(CudaPlatformBase):
+
+    @classmethod
+    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+        major, minor = torch.cuda.get_device_capability(device_id)
+        return DeviceCapability(major=major, minor=minor)
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        return torch.cuda.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.cuda.get_device_properties(device_id)
+        return device_props.total_memory
+
+    @classmethod
+    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+        logger.exception(
+            "NVLink detection not possible, as context support was"
+            " not found. Assuming no NVLink available.")
+        return False
+
+
+# Autodetect either NVML-enabled or non-NVML platform
+# based on whether NVML is available.
+nvml_available = False
+try:
+    try:
+        pynvml.nvmlInit()
+        nvml_available = True
+    except Exception:
+        # On Jetson, NVML is not supported.
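+        # Fall back to the non-NVML platform class in that case.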
+ nvml_available = False +finally: + if nvml_available: + pynvml.nvmlShutdown() + +CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if not isinstance(pynvml, _MockModule): + CudaPlatform.log_warnings() +except ModuleNotFoundError: + CudaPlatform.log_warnings() diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py new file mode 100644 index 0000000..4c842b5 --- /dev/null +++ b/vllm/platforms/hpu.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm import envs +from vllm.logger import init_logger + +from .interface import Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +logger = init_logger(__name__) + + +class HpuPlatform(Platform): + _enum = PlatformEnum.HPU + device_name: str = "hpu" + device_type: str = "hpu" + dispatch_key: str = "HPU" + ray_device_key: str = "HPU" + device_control_env_var: str = "HABANA_VISIBLE_MODULES" + + @classmethod + def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + dtype: torch.dtype, kv_cache_dtype: Optional[str], + block_size: int, use_v1: bool, + use_mla: bool) -> str: + logger.info("Using HPUAttention backend.") + return "vllm.attention.backends.hpu_attn.HPUAttentionBackend" + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + + @staticmethod + def inference_mode(): + return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + + scheduler_config = vllm_config.scheduler_config + if scheduler_config.is_multi_step: + raise NotImplementedError( + "Multi-step execution is not implemented for HPU") + + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "Speculative decoding is not implemented for HPU") + + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker" + + # NOTE(kzawora): default block size for Gaudi should be 128 + # smaller sizes still work, but very inefficiently + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 128 + if (parallel_config.distributed_executor_backend == 'mp' + and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'): + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", + None) is not None: + logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork " + "might cause application hangs on exit. Using " + "VLLM_WORKER_MULTIPROC_METHOD=fork anyway, " + "as it was explicitly requested.") + else: + logger.warning( + "On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork " + "might cause application hangs on exit. Setting " + "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. 
" + "To override that behavior, please set " + "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on HPU.") + return False + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU" + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py new file mode 100644 index 0000000..a5b58aa --- /dev/null +++ b/vllm/platforms/interface.py @@ -0,0 +1,393 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +import platform +import random +from platform import uname +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union + +import numpy as np +import torch + +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + from vllm.utils import FlexibleArgumentParser +else: + ModelConfig = None + VllmConfig = None + FlexibleArgumentParser = None + +logger = init_logger(__name__) + + +def in_wsl() -> bool: + # Reference: https://github.com/microsoft/WSL/issues/4071 + return "microsoft" in " ".join(uname()).lower() + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + FLASH_ATTN_VLLM_V1 = enum.auto() + TRITON_ATTN_VLLM_V1 = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + FLASHINFER = enum.auto() + TRITON_MLA = enum.auto() # Supported by V1 + FLASHMLA = enum.auto() # Supported by V1 + HPU_ATTN = enum.auto() + PALLAS = enum.auto() + PALLAS_VLLM_V1 = enum.auto() + IPEX = enum.auto() + BLOCK_SPARSE_FLASH_ATTN = enum.auto() + NO_ATTENTION = enum.auto() + + +class PlatformEnum(enum.Enum): + CUDA = enum.auto() + ROCM = enum.auto() + TPU = enum.auto() + HPU = enum.auto() + XPU = enum.auto() + CPU = enum.auto() + NEURON = enum.auto() + OOT = enum.auto() + UNSPECIFIED = enum.auto() + + +class CpuArchEnum(enum.Enum): + X86 = enum.auto() + ARM = enum.auto() + POWERPC = enum.auto() + OTHER = enum.auto() + UNKNOWN = enum.auto() + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform: + _enum: PlatformEnum + device_name: str + device_type: str + + # available dispatch keys: + # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa + # use "CPU" as a fallback for platforms not registered in PyTorch + dispatch_key: str = "CPU" + + # available ray device keys: + # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa + # empty string means the device does not support ray + ray_device_key: str = "" + + # platform-agnostic way to specify the device control environment variable, + # .e.g. CUDA_VISIBLE_DEVICES for CUDA. 
+    # hint: search for "get_visible_accelerator_ids_env_var" in
+    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
+    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"
+
+    # The torch.compile backend for compiling simple and
+    # standalone functions. The default value is "inductor" to keep
+    # the same behavior as PyTorch.
+    # NOTE: for the forward part of the model, vLLM has another separate
+    # compilation strategy.
+    simple_compile_backend: str = "inductor"
+
+    supported_quantization: list[str] = []
+
+    additional_env_vars: list[str] = []
+
+    def is_cuda(self) -> bool:
+        return self._enum == PlatformEnum.CUDA
+
+    def is_rocm(self) -> bool:
+        return self._enum == PlatformEnum.ROCM
+
+    def is_tpu(self) -> bool:
+        return self._enum == PlatformEnum.TPU
+
+    def is_hpu(self) -> bool:
+        return self._enum == PlatformEnum.HPU
+
+    def is_xpu(self) -> bool:
+        return self._enum == PlatformEnum.XPU
+
+    def is_cpu(self) -> bool:
+        return self._enum == PlatformEnum.CPU
+
+    def is_neuron(self) -> bool:
+        return self._enum == PlatformEnum.NEURON
+
+    def is_out_of_tree(self) -> bool:
+        return self._enum == PlatformEnum.OOT
+
+    def is_cuda_alike(self) -> bool:
+        """Stateless version of :func:`torch.cuda.is_available`."""
+        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+
+    @classmethod
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool,
+                             use_mla: bool) -> str:
+        """Get the attention backend class of a device."""
+        return ""
+
+    @classmethod
+    def get_device_capability(
+        cls,
+        device_id: int = 0,
+    ) -> Optional[DeviceCapability]:
+        """Stateless version of :func:`torch.cuda.get_device_capability`."""
+        return None
+
+    @classmethod
+    def has_device_capability(
+        cls,
+        capability: Union[Tuple[int, int], int],
+        device_id: int = 0,
+    ) -> bool:
+        """
+        Test whether this platform is compatible with a device capability.
+
+        The ``capability`` argument can either be:
+
+        - A tuple ``(major, minor)``.
+        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
+        """
+        current_capability = cls.get_device_capability(device_id=device_id)
+        if current_capability is None:
+            return False
+
+        if isinstance(capability, tuple):
+            return current_capability >= capability
+
+        return current_capability.to_int() >= capability
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Get the name of a device."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_uuid(cls, device_id: int = 0) -> str:
+        """Get the uuid of a device, e.g. the PCI bus ID."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        """Get the total memory of a device in bytes."""
+        raise NotImplementedError
+
+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        """
+        Check if the current platform supports async output.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def inference_mode(cls):
+        """A device-specific wrapper of `torch.inference_mode`.
+
+        This wrapper is recommended because some hardware backends such as TPU
+        do not support `torch.inference_mode`. In such a case, they will fall
+        back to `torch.no_grad` by overriding this method.
+        """
+        return torch.inference_mode(mode=True)
+
+    @classmethod
+    def seed_everything(cls, seed: Optional[int] = None) -> None:
+        """
+        Set the seed of each random module.
+ `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + if seed is not None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @classmethod + def pre_register_and_update(cls, + parser: Optional[FlexibleArgumentParser] = None + ) -> None: + """ + Do some pre-registeration or update action for the current platform. + + This function is called before global VllmConfig is initialized or cli + arguments are parsed. It's used for out-of-tree platforms to register or + update the configuration. + + For example, the out-of-tree quantization config can be imported and + registered here dynamically. + """ + pass + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + """ + Check and update the configuration for the current platform. + + It can raise an exception if the configuration is not compatible with + the current platform, or it can update the configuration to make it + compatible with the current platform. + + The config is passed by reference, so it can be modified in place. + """ + pass + + @classmethod + def verify_model_arch(cls, model_arch: str) -> None: + """ + Verify whether the current platform supports the specified model + architecture. + + - This will raise an Error or Warning based on the model support on + the current platform. + - By default all models are considered supported. + """ + pass + + @classmethod + def verify_quantization(cls, quant: str) -> None: + """ + Verify whether the quantization is supported by the current platform. + """ + if cls.supported_quantization and \ + quant not in cls.supported_quantization: + raise ValueError( + f"{quant} quantization is currently not supported in " + f"{cls.device_name}.") + + @classmethod + def get_cpu_architecture(cls) -> CpuArchEnum: + """ + Determine the CPU architecture of the current system. + Returns CpuArchEnum indicating the architecture type. + """ + machine = platform.machine().lower() + + if machine in ("x86_64", "amd64", "i386", "i686"): + return CpuArchEnum.X86 + elif machine.startswith("arm") or machine.startswith("aarch"): + return CpuArchEnum.ARM + elif machine.startswith("ppc"): + return CpuArchEnum.POWERPC + + return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN + + @classmethod + def is_pin_memory_available(cls) -> bool: + """Checks whether pin memory is available on the current platform.""" + if in_wsl(): + # Pinning memory in WSL is not supported. + # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications + logger.warning("Using 'pin_memory=False' as WSL is detected. " + "This may slow down the performance.") + return False + return True + + @classmethod + def get_current_memory_usage(cls, + device: Optional[torch.types.Device] = None + ) -> float: + """ + Return the memory usage in bytes. + """ + raise NotImplementedError + + @classmethod + def get_punica_wrapper(cls) -> str: + """ + Return the punica wrapper for current platform. + """ + raise NotImplementedError + + @classmethod + def get_device_communicator_cls(cls) -> str: + """ + Get device specific communicator class for distributed communication. + """ + return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + + @classmethod + def supports_fp8(cls) -> bool: + """ + Returns whether the current platform supports FP8 types. 
+ """ + return False + + @classmethod + def is_fp8_fnuz(cls) -> bool: + """ + Returns whether the preferred FP8 type is FNUZ on the current platform. + + There are two representations of FP8, OCP FP8 and FNUZ FP8. + The OCP specification can be found at https://tinyurl.com/b7jvwpft. + The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5. + + AMD's MI300 and MI325 have native hardware support for FNUZ. All other + hardware has converged on the OCP FP8 standard. + """ + return False + + @classmethod + def fp8_dtype(cls) -> torch.dtype: + """ + Returns the preferred FP8 type on the current platform. + + See the documentation for is_fp8_fnuz for details. + """ + return torch.float8_e4m3fn + + @classmethod + def use_all_gather(cls) -> bool: + """ + Whether to use allgather in LogitsProcessor to gather the logits. + """ + import vllm.envs as envs + from vllm.config import get_current_vllm_config + + parallel_config = get_current_vllm_config().parallel_config + return (envs.VLLM_USE_V1 + or parallel_config.distributed_executor_backend + == "external_launcher") + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + """Returns whether the current platform can support v1 for the supplied + model configuration. + """ + return False + + @classmethod + def use_custom_allreduce(cls) -> bool: + """ + Returns if custom allreduce is supported on the current platform + """ + return False + + +class UnspecifiedPlatform(Platform): + _enum = PlatformEnum.UNSPECIFIED + device_type = "" diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py new file mode 100644 index 0000000..c1f426e --- /dev/null +++ b/vllm/platforms/neuron.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Optional + +from vllm import envs +from vllm.logger import init_logger + +from .interface import Platform, PlatformEnum + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +logger = init_logger(__name__) + + +class NeuronPlatform(Platform): + _enum = PlatformEnum.NEURON + device_name: str = "neuron" + device_type: str = "neuron" + ray_device_key: str = "neuron_cores" + supported_quantization: list[str] = ["neuron_quant"] + device_control_env_var: str = "NEURON_RT_VISIBLE_CORES" + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return "neuron" + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = \ + "vllm.worker.neuron_worker.NeuronWorker" + + if parallel_config.world_size > 1: + parallel_config.distributed_executor_backend = "uni" + + assert (vllm_config.lora_config + is None), "LoRA is not supported for Neuron backend." + assert (not vllm_config.speculative_config + ), "Speculative decoding not yet supported for Neuron backend." 
+ + cache_config = vllm_config.cache_config + if cache_config: + # neuron needs block_size = max_model_len + vllm_config.cache_config.block_size = \ + vllm_config.model_config.max_model_len + + @classmethod + def is_pin_memory_available(cls) -> bool: + logger.warning("Pin memory is not supported on Neuron.") + return False + + @classmethod + def get_device_communicator_cls(cls) -> str: + if envs.VLLM_USE_V1: + return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa + else: + return Platform.get_device_communicator_cls() + + @classmethod + def use_all_gather(cls) -> bool: + return True diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py new file mode 100644 index 0000000..d18b7c2 --- /dev/null +++ b/vllm/platforms/rocm.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from functools import cache, lru_cache, wraps +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger + +from .interface import DeviceCapability, Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig +else: + ModelConfig = None + VllmConfig = None + +logger = init_logger(__name__) + +try: + from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info, + amdsmi_get_processor_handles, amdsmi_init, + amdsmi_shut_down, amdsmi_topo_get_link_type) +except ImportError as e: + logger.warning("Failed to import from amdsmi with %r", e) + +try: + import vllm._C # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + +# import custom ops, trigger op registration +try: + import vllm._rocm_C # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._rocm_C with %r", e) + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS: List[str] = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { + "Qwen2ForCausalLM": + _ROCM_SWA_REASON, + "MistralForCausalLM": + _ROCM_SWA_REASON, + "MixtralForCausalLM": + _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": + ("ROCm flash attention does not yet " + "fully support 32-bit precision on PaliGemma"), + "Phi3VForCausalLM": + ("ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") +} + +# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`` +if "HIP_VISIBLE_DEVICES" in os.environ: + val = os.environ["HIP_VISIBLE_DEVICES"] + if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None): + assert val == cuda_val + else: + os.environ["CUDA_VISIBLE_DEVICES"] = val + +# AMDSMI utils +# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. 
+# the major benefit of using AMDSMI is that it will not initialize CUDA + + +def with_amdsmi_context(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + amdsmi_init() + try: + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + + return wrapper + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + + +@cache +def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, gqa_ratio: int, + max_seq_len: int, + sliding_window: int) -> bool: + + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + ON_NAVI = "gfx1" in GPU_ARCH + ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) + + # rocm custom page attention not support on navi (gfx1*) + return (ON_MI250_MI300 and not ON_NAVI + and (sliding_window == 0 or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + + +class RocmPlatform(Platform): + _enum = PlatformEnum.ROCM + device_name: str = "rocm" + device_type: str = "cuda" + dispatch_key: str = "CUDA" + ray_device_key: str = "GPU" + # rocm shares the same device control env var as CUDA + device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + + supported_quantization: list[str] = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8", "gguf", "quark", "ptpc_fp8" + ] + + @classmethod + def get_attn_backend_cls(cls, selected_backend, head_size, dtype, + kv_cache_dtype, block_size, use_v1, + use_mla) -> str: + if use_mla: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" + selected_backend = (_Backend.ROCM_FLASH if selected_backend + == _Backend.FLASH_ATTN else selected_backend) + if envs.VLLM_USE_V1: + logger.info("Using Triton Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "triton_attn.TritonAttentionBackend") + if selected_backend == _Backend.ROCM_FLASH: + if not cls.has_device_capability(90): + # not Instinct series GPUs. 
+ logger.info("flash_attn is not supported on NAVI GPUs.") + else: + logger.info("%s is not supported in AMD GPUs.", selected_backend) + logger.info("Using ROCmFlashAttention backend.") + return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 + + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, + device_id: int = 0 + ) -> Optional[DeviceCapability]: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @staticmethod + @with_amdsmi_context + def is_fully_connected(physical_device_ids: List[int]) -> bool: + """ + Query if the set of gpus are fully connected by xgmi (1 hop) + """ + handles = [ + amdsmi_get_processor_handles()[i] for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type( + handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD 1 hop XGMI detection failed.", + exc_info=error) + return False + return True + + @classmethod + @with_amdsmi_context + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = amdsmi_get_processor_handles()[physical_device_id] + return amdsmi_get_gpu_asic_info(handle)["market_name"] + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.cuda.get_device_properties(device_id) + return device_props.total_memory + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. Please launch without " + "--num-scheduler-steps.") + else: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Speculative decoding is not yet supported on vLLM V1." 
+ ) + else: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.gpu_worker.Worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" + + @classmethod + def verify_model_arch(cls, model_arch: str) -> None: + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError(f"Model architecture '{model_arch}' is not " + "supported by ROCm for now.") + + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] + logger.warning( + "Model architecture '%s' is partially " + "supported by ROCm: %s", model_arch, msg) + + @classmethod + def verify_quantization(cls, quant: str) -> None: + super().verify_quantization(quant) + if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ: + logger.warning( + "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" + " is not set, enabling VLLM_USE_TRITON_AWQ.") + envs.VLLM_USE_TRITON_AWQ = True + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" + + @classmethod + def get_current_memory_usage(cls, + device: Optional[torch.types.Device] = None + ) -> float: + torch.cuda.reset_peak_memory_stats(device) + return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info( + device)[0] + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + + @classmethod + def supports_fp8(cls) -> bool: + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12']) + + @classmethod + def is_fp8_fnuz(cls) -> bool: + # only device 0 is checked, this assumes MI300 platforms are homogeneous + return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName + + @classmethod + def fp8_dtype(cls) -> torch.dtype: + if cls.is_fp8_fnuz(): + return torch.float8_e4m3fnuz + else: + return torch.float8_e4m3fn + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + # V1 support on AMD gpus is experimental + return True + + @classmethod + def use_custom_allreduce(cls) -> bool: + # We only enable custom allreduce for MI300 series + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + supported_archs = ['gfx94'] + return any(gfx in gcn_arch for gfx in supported_archs) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py new file mode 100644 index 0000000..43d3044 --- /dev/null +++ b/vllm/platforms/tpu.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Optional + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger + +from .interface import Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig +else: + ModelConfig = None + VllmConfig = None + +logger = init_logger(__name__) + + +class TpuPlatform(Platform): + _enum = PlatformEnum.TPU + device_name: str = "tpu" + device_type: str = "tpu" + dispatch_key: str = "XLA" + ray_device_key: str = "TPU" + device_control_env_var: str = "TPU_VISIBLE_CHIPS" + + supported_quantization: list[str] = [ + "tpu_int8", "compressed-tensors", "compressed_tensors" + ] + + additional_env_vars: list[str] = [ + "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" + ] + + @classmethod + def get_attn_backend_cls(cls, selected_backend: _Backend, 
head_size: int, + dtype: torch.dtype, kv_cache_dtype: Optional[str], + block_size: int, use_v1: bool, + use_mla: bool) -> str: + if (selected_backend != _Backend.PALLAS + and selected_backend != _Backend.PALLAS_VLLM_V1): + logger.info("Cannot use %s backend on TPU.", selected_backend) + + if use_v1: + logger.info("Using Pallas V1 backend.") + return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + else: + logger.info("Using Pallas backend.") + return "vllm.attention.backends.pallas.PallasAttentionBackend" + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return "tpu" + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return not envs.VLLM_USE_V1 + + @classmethod + def inference_mode(cls): + return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.config import CompilationLevel + + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + + compilation_config = vllm_config.compilation_config + + # TPU only supports DYNAMO_ONCE compilation level + if compilation_config.level != CompilationLevel.DYNAMO_ONCE: + logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") + compilation_config.level = CompilationLevel.DYNAMO_ONCE + + if compilation_config.backend == "": + compilation_config.backend = "openxla" + + assert vllm_config.speculative_config is None, \ + "TPU does not support speculative decoding" + + if vllm_config.model_config.dtype in (torch.float16, torch.float32): + logger.warning( + "The TPU backend currently does not support %s. " + "Using bfloat16 instead.", vllm_config.model_config.dtype) + vllm_config.model_config.dtype = torch.bfloat16 + + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. 
Please launch without " + "--num-scheduler-steps.") + else: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.tpu_worker.TPUWorker" + else: + parallel_config.worker_cls = \ + "vllm.worker.tpu_worker.TPUWorker" + + assert not vllm_config.speculative_config, ( + "Speculative decoding is not yet supported for TPU backend") + + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on TPU.") + return False + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa + + @classmethod + def use_all_gather(cls) -> bool: + return True + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + # V1 support on TPU is experimental + return True diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py new file mode 100644 index 0000000..225e756 --- /dev/null +++ b/vllm/platforms/xpu.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm.logger import init_logger + +from .interface import DeviceCapability, Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +logger = init_logger(__name__) + + +class XPUPlatform(Platform): + _enum = PlatformEnum.XPU + device_name: str = "xpu" + device_type: str = "xpu" + dispatch_key: str = "XPU" + # Intel XPU's device key is "GPU" for Ray. + # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501 + ray_device_key: str = "GPU" + device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR" + + @classmethod + def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + dtype: torch.dtype, kv_cache_dtype: Optional[str], + block_size: int, use_v1: bool, + use_mla: bool) -> str: + if selected_backend != _Backend.IPEX: + logger.info("Cannot use %s backend on XPU.", selected_backend) + logger.info("Using IPEX attention backend.") + return "vllm.attention.backends.ipex_attn.IpexAttnBackend" + + @staticmethod + def get_device_capability( + device_id: int = 0) -> Optional[DeviceCapability]: + # capacity format differs from cuda's and will cause unexpected + # failure, so use None directly + return None + + @staticmethod + def get_device_name(device_id: int = 0) -> str: + return torch.xpu.get_device_name(device_id) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.xpu.get_device_properties(device_id) + return device_props.total_memory + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + + @staticmethod + def inference_mode(): + return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + + # check and update model config + model_config = vllm_config.model_config + if model_config.dtype == torch.bfloat16: + bf16_supported = cls.device_support_bf16() + if not bf16_supported: + logger.warning( + "bfloat16 is only supported on Intel Data Center GPU, " + "Intel Arc GPU is not supported yet. Your device is %s," + " which is not supported. 
will fallback to float16", + cls.get_device_name()) + model_config.dtype = torch.float16 + if not model_config.enforce_eager: + logger.warning( + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True + + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "XPU does not support speculative decoding") + + if vllm_config.device_config is not None: + assert vllm_config.device_config.device_type == "xpu" + + # check and update parallel config + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" + + if parallel_config.distributed_executor_backend is None: + parallel_config.distributed_executor_backend = "ray" + elif parallel_config.distributed_executor_backend == "mp": + # FIXME(kunshang): + # spawn needs calling `if __name__ == '__main__':`` + # fork is not supported for xpu start new process. + logger.error( + "Both start methods (spawn and fork) have issue " + "on XPU if you use mp backend, setting it to ray instead.") + parallel_config.distributed_executor_backend = "ray" + + elif parallel_config.distributed_executor_backend != "ray": + logger.warning( + "%s is not supported on XPU, fallback to ray distributed" + " executor backend.", + parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend = "ray" + + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on XPU.") + return False + + @classmethod + def get_current_memory_usage(cls, + device: Optional[torch.types.Device] = None + ) -> float: + torch.xpu.reset_peak_memory_stats(device) + return torch.xpu.max_memory_allocated(device) + + @classmethod + def device_support_bf16(cls) -> bool: + device_name = cls.get_device_name().lower() + if device_name.count("arc") > 0: + return False + elif device_name.count("data center gpu") > 0: + return True + else: + logger.warning("Unknown device name %s, always use float16", + device_name) + return False + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py new file mode 100644 index 0000000..389cb87 --- /dev/null +++ b/vllm/plugins/__init__.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from typing import Callable, Dict + +import torch + +import vllm.envs as envs + +logger = logging.getLogger(__name__) + +# make sure one process only loads plugins once +plugins_loaded = False + + +def load_plugins_by_group(group: str) -> Dict[str, Callable]: + import sys + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group=group) + if len(discovered_plugins) == 0: + logger.debug("No plugins for group %s found.", group) + return {} + logger.info("Available plugins for group %s:", group) + for plugin in discovered_plugins: + logger.info("name=%s, value=%s", plugin.name, plugin.value) + if allowed_plugins is None: + logger.info("all available plugins for group %s will be loaded.", + group) + logger.info("set environment variable VLLM_PLUGINS to control" + " which plugins to load.") + plugins = {} + for plugin in discovered_plugins: + if allowed_plugins is None or plugin.name in allowed_plugins: + try: 
+                func = plugin.load()
+                plugins[plugin.name] = func
+                logger.info("plugin %s loaded.", plugin.name)
+            except Exception:
+                logger.exception("Failed to load plugin %s", plugin.name)
+    return plugins
+
+
+def load_general_plugins():
+    """WARNING: plugins can be loaded for multiple times in different
+    processes. They should be designed in a way that they can be loaded
+    multiple times without causing issues.
+    """
+    global plugins_loaded
+    if plugins_loaded:
+        return
+    plugins_loaded = True
+
+    # some platform-specific configurations
+    from vllm.platforms import current_platform
+
+    if current_platform.is_xpu():
+        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
+        torch._dynamo.config.disable = True
+    elif current_platform.is_hpu():
+        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
+        # does not support torch.compile
+        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
+        # torch.compile support
+        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
+        if is_lazy:
+            torch._dynamo.config.disable = True
+            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
+            # requires enabling lazy collectives
+            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
+            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+
+    plugins = load_plugins_by_group(group='vllm.general_plugins')
+    # general plugins, we only need to execute the loaded functions
+    for func in plugins.values():
+        func()
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
new file mode 100644
index 0000000..061232e
--- /dev/null
+++ b/vllm/pooling_params.py
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import msgspec
+
+
+class PoolingParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+    """API parameters for pooling models. This is currently a placeholder.
+
+    Attributes:
+        additional_data: Any additional data needed for pooling.
+ """ + additional_data: Optional[Any] = None + + def clone(self) -> "PoolingParams": + """Returns a deep copy of the PoolingParams instance.""" + return PoolingParams(additional_data=self.additional_data) + + def __repr__(self) -> str: + return (f"PoolingParams(" + f"additional_metadata={self.additional_data})") diff --git a/vllm/profiler/__init__.py b/vllm/profiler/__init__.py new file mode 100644 index 0000000..00af72b --- /dev/null +++ b/vllm/profiler/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .layerwise_profile import layerwise_profile + +__all__ = [ + "layerwise_profile", +] diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py new file mode 100644 index 0000000..6351ef6 --- /dev/null +++ b/vllm/profiler/layerwise_profile.py @@ -0,0 +1,374 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union + +import pandas as pd +from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult +from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent +from torch.autograd.profiler import FunctionEvent +from torch.profiler import ProfilerActivity, profile + +from vllm.profiler.utils import (TablePrinter, event_has_module, + event_is_torch_op, event_module_repr, + event_torch_op_stack_trace, indent_string) + + +@dataclass +class _ModuleTreeNode: + event: _ProfilerEvent + parent: Optional['_ModuleTreeNode'] = None + children: List['_ModuleTreeNode'] = field(default_factory=list) + trace: str = "" + + @property + def is_leaf(self): + return (self.event.children is None or len(self.event.children) == 0) + + @property + def is_torch_op(self): + return event_is_torch_op(self.event) + + @property + def is_cuda(self): + return (self.event.tag == _EventType.Kineto + and self.event.typed[1].device_type == DeviceType.CUDA) + + +@dataclass +class SummaryStatsEntry: + name: str + cuda_time_us: float + pct_cuda_time: float + invocations: int + + +@dataclass +class ModelStatsEntry: + name: str + cpu_time_us: float + cuda_time_us: float + pct_cuda_time: float + trace: str + + +StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] + + +@dataclass +class _StatsTreeNode: + entry: StatsEntry + children: List[StatsEntry] + parent: Optional[StatsEntry] + + +@dataclass +class LayerwiseProfileResults(profile): + _kineto_results: _ProfilerResult + _kineto_event_correlation_map: Dict[int, + List[_KinetoEvent]] = field(init=False) + _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) + _module_tree: List[_ModuleTreeNode] = field(init=False) + _model_stats_tree: List[_StatsTreeNode] = field(init=False) + _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + + # profile metadata + num_running_seqs: Optional[int] = None + + def __post_init__(self): + self._build_correlation_map() + self._build_module_tree() + self._build_stats_trees() + + def print_model_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + if column_widths: + _column_widths.update(**column_widths) + filtered_model_table = [ + (depth, row) + for depth, row in self._flatten_stats_tree(self._model_stats_tree) + if row.cuda_time_us > 0 or row.cpu_time_us > 0 + ] + TablePrinter(ModelStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( 
+ filtered_model_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def print_summary_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + if column_widths: + _column_widths.update(**column_widths) + filtered_summary_table = [(depth, row) + for depth, row in self._flatten_stats_tree( + self._summary_stats_tree) + if row.cuda_time_us > 0] + TablePrinter(SummaryStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_summary_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def export_model_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._model_stats_tree) + ]) + df.to_csv(filename) + + def export_summary_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ]) + df.to_csv(filename) + + def convert_stats_to_dict(self) -> dict[str, Any]: + return { + "metadata": { + "num_running_seqs": self.num_running_seqs + }, + "summary_stats": + self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": + self._convert_stats_tree_to_dict(self._model_stats_tree) + } + + @staticmethod + def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + StatsEntry]], + indent_style: Union[Callable[[int], + str], + str] = " "): + indented_rows = [] + for depth, row in depths_rows: + if row.cuda_time_us == 0: + continue + indented_row = copy.deepcopy(row) + indented_row.name = indent_string(indented_row.name, depth, + indent_style) + indented_rows.append(indented_row) + return indented_rows + + def _build_correlation_map(self): + self._kineto_event_correlation_map = defaultdict(list) + for event in self._kineto_results.events(): + self._kineto_event_correlation_map[event.correlation_id()].append( + event) + + def _build_module_tree(self): + self._module_tree = [] + event_tree = self._kineto_results.experimental_event_tree() + + def _df_traversal(event: _ProfilerEvent, + curr_node: Optional[_ModuleTreeNode] = None): + + # For the tensor parallel case for now only look at task 1 + if event.start_tid != 1: + return + + if event_has_module(event): + node = _ModuleTreeNode(event=event, parent=curr_node) + if curr_node: + curr_node.children.append(node) + else: + self._module_tree.append(node) + curr_node = node + + is_leaf = (event.children is None or len(event.children) == 0) + if is_leaf and curr_node: + node = _ModuleTreeNode( + event=event, + parent=curr_node, + trace=event_torch_op_stack_trace( + event, until=lambda x: event_has_module(x))) + curr_node.children.append(node) + curr_node = node + + for child in event.children: + _df_traversal(child, curr_node) + + for root in event_tree: + _df_traversal(root) + + def _get_kineto_gpu_event(self, node: _ModuleTreeNode): + if node.event.tag != _EventType.Kineto: + return None + correlated_kineto_events = self._kineto_event_correlation_map.get( + node.event.correlation_id, []) + iterator = (x for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA + and x.name() == node.event.name) + return next(iterator, None) + + def _cumulative_cuda_time(self, node: _ModuleTreeNode): + 'Return cuda time in microseconds' + + def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): + if node.is_leaf and (gpu_kineto_event := + self._get_kineto_gpu_event(node)): + return gpu_kineto_event.duration_ns() / 1000.0 + else: + 
cumulative_cuda_time = 0 + for child in node.children: + cumulative_cuda_time += _cumulative_cuda_time_recursive( + child) + return cumulative_cuda_time + + return _cumulative_cuda_time_recursive(node) + + def _total_cuda_time(self): + return sum( + [self._cumulative_cuda_time(root) for root in self._module_tree]) + + def _build_stats_trees(self): + summary_dict: Dict[str, _StatsTreeNode] = {} + total_cuda_time = self._total_cuda_time() + + def pct_cuda_time(cuda_time_us): + return (cuda_time_us / total_cuda_time) * 100 + + def build_summary_stats_tree_df( + node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None, + summary_trace: Tuple[str] = ()): + + if event_has_module(node.event): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 + else: + return None + + summary_trace = summary_trace + (name, ) + if summary_trace in summary_dict: + entry = summary_dict[summary_trace].entry + entry.cuda_time_us += cuda_time_us + entry.invocations += 1 + entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) + else: + new_node = _StatsTreeNode(entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1), + children=[], + parent=parent) + if parent: + parent.children.append(new_node) + summary_dict[summary_trace] = new_node + + for child in node.children: + build_summary_stats_tree_df(child, summary_dict[summary_trace], + summary_trace) + + return summary_dict[summary_trace] + + self._summary_stats_tree = [] + for root in self._module_tree: + self._summary_stats_tree.append(build_summary_stats_tree_df(root)) + + def build_model_stats_tree_df(node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None): + if event_has_module(node.event, ): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + cpu_time_us = node.event.duration_time_ns / 1000 + trace = "" + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 + cpu_time_us = 0 + trace = node.trace + else: + return None + + new_node = _StatsTreeNode(entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace), + parent=parent, + children=[]) + if parent: + parent.children.append(new_node) + + for child in node.children: + build_model_stats_tree_df(child, new_node) + + return new_node + + self._model_stats_tree = [] + for root in self._module_tree: + self._model_stats_tree.append(build_model_stats_tree_df(root)) + + def _flatten_stats_tree( + self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: + entries: List[Tuple[int, StatsEntry]] = [] + + def df_traversal(node: _StatsTreeNode, depth=0): + entries.append((depth, node.entry)) + for child in node.children: + df_traversal(child, depth=depth + 1) + + for root in tree: + df_traversal(root) + + return entries + + def _convert_stats_tree_to_dict(self, + tree: List[_StatsTreeNode]) -> List[Dict]: + root_dicts: List[Dict] = [] + + def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + curr_json_list.append({ + "entry": asdict(node.entry), + "children": [] + }) + for child in node.children: + df_traversal(child, curr_json_list[-1]["children"]) + + for root in tree: + df_traversal(root, 
root_dicts) + + return root_dicts + + +class layerwise_profile(profile): + + def __init__(self, num_running_seqs: Optional[int] = None): + """ + layerwise profile constructor. + + Args: + num_running_seqs (Optional[int], optional): When given, + num_running_seqs will be passed to LayerProfileResults for metadata + update. Defaults to None. + """ + super().__init__( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + with_modules=True, + experimental_config=_ExperimentalConfig(verbose=True)) + + self.num_running_seqs = num_running_seqs + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + self.results = LayerwiseProfileResults( + self.profiler.kineto_results, + num_running_seqs=self.num_running_seqs) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py new file mode 100644 index 0000000..62b39f5 --- /dev/null +++ b/vllm/profiler/utils.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from typing import Callable, Dict, List, Type, Union + +from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata + +# +# String / Print Manipulation +# + + +def trim_string_front(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[offset:] + if len(string) > 3: + string = "..." + string[3:] + return string + + +def trim_string_back(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." + return string + + +class TablePrinter: + + def __init__(self, row_cls: Type[dataclasses.dataclass], + column_widths: Dict[str, int]): + self.row_cls = row_cls + self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] + self.column_widths = column_widths + assert set(self.column_widths.keys()) == set(self.fieldnames) + + def print_table(self, rows: List[dataclasses.dataclass]): + self._print_header() + self._print_line() + for row in rows: + self._print_row(row) + + def _print_header(self): + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + print(trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n") + + def _print_row(self, row): + assert isinstance(row, self.row_cls) + + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + val = getattr(row, f) + + val_str = "" + if isinstance(val, str): + val_str = trim_string_back(val, col_width).ljust(col_width) + elif type(val) in [float, int]: + val_str = f"{float(val):>.2f}".rjust(col_width) + else: + val_str = f"{val}".rjust(col_width) + print(val_str, end=" | " if not last else "\n") + + def _print_line(self): + total_col_width = 0 + for column_width in self.column_widths.values(): + total_col_width += column_width + print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) + + +def indent_string(string: str, + indent: int, + indent_style: Union[Callable[[int], str], str] = " ") -> str: + if indent: + if isinstance(indent_style, str): + return indent_style * indent + string + else: + return indent_style(indent) + string + else: + return string + + +# +# _ProfilerEvent utils +# + + +def event_has_module(event: _ProfilerEvent) -> bool: + event_type, typed_event = event.typed + if event_type == _EventType.PyCall: + return typed_event.module is not None + return False 
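# Illustrative sketch of how the string helpers above behave (example values are
# assumptions, shown for orientation only, not taken from the patch):
#
#   trim_string_back("abcdefghij", 8)                        -> "abcde..."
#   indent_string("attn", 2, "|")                            -> "||attn"
#   indent_string("attn", 2, lambda d: "|" + "-" * d + " ")  -> "|-- attn"
#
# The callable form on the last line matches the indent_style that
# LayerwiseProfileResults passes when printing its per-module and summary tables.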
+ + +def event_is_torch_op(event: _ProfilerEvent) -> bool: + return event.tag == _EventType.TorchOp + + +def event_arg_repr(arg) -> str: + if arg is None or type(arg) in [float, int, bool, str]: + return f"{arg}" + elif isinstance(arg, list): + return f"[{', '.join([event_arg_repr(x) for x in arg])}]" + elif isinstance(arg, tuple): + return f"({', '.join([event_arg_repr(x) for x in arg])})" + else: + assert isinstance(arg, + _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ', '.join([str(x) for x in arg.sizes]) + return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" + + +def event_torch_op_repr(event: _ProfilerEvent) -> str: + assert event.tag == _EventType.TorchOp + args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + return f"{event.name}({args_str})".replace("aten::", "") + + +def event_module_repr(event: _ProfilerEvent) -> str: + assert event_has_module(event) + module = event.typed[1].module + if module.parameters and len(module.parameters) > 0: + args_str = ', '.join( + [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + return f"{module.cls_name}({args_str})" + else: + return module.cls_name + + +def event_torch_op_stack_trace(curr_event: _ProfilerEvent, + until: Callable[[_ProfilerEvent], bool]) -> str: + trace = "" + curr_event = curr_event.parent + while curr_event and not until(curr_event): + if event_is_torch_op(curr_event): + if len(trace) > 0: + trace += " <- " + trace += event_torch_op_repr(curr_event) + curr_event = curr_event.parent + + return trace diff --git a/vllm/prompt_adapter/__init__.py b/vllm/prompt_adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py new file mode 100644 index 0000000..c2f9f16 --- /dev/null +++ b/vllm/prompt_adapter/layers.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import PromptAdapterConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + + +@dataclass +class PromptAdapterMapping(AdapterMapping): + pass + + +class VocabParallelEmbeddingWithPromptAdapter(nn.Module): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.emb_layer = self.base_layer + if 'LoRA' in base_layer.__class__.__name__: + self.emb_layer = self.base_layer.base_layer + + def create_prompt_adapter_weights( + self, prompt_adapter_config: PromptAdapterConfig): + self.embeddings_tensors = torch.zeros( + ( + prompt_adapter_config.max_prompt_adapters, + prompt_adapter_config.max_prompt_adapter_token, + self.emb_layer.embedding_dim, + ), + dtype=self.emb_layer.weight.dtype, + device=self.emb_layer.weight.device, + ) + self.adapter_lengths = torch.zeros( + prompt_adapter_config.max_prompt_adapters, + dtype=torch.long, + device=self.emb_layer.weight.device) + + self.indices_gpu: torch.Tensor + self.embedding_indices_gpu: torch.Tensor + + def reset_prompt_adapter(self, index: int): + self.embeddings_tensors[index] = 0 + + def set_prompt_adapter( + self, + index: int, + adapter_model: Optional[torch.Tensor], + ): + self.reset_prompt_adapter(index) + if adapter_model is not None: + length = adapter_model.shape[0] + self.embeddings_tensors[index, :length] = adapter_model + self.adapter_lengths[index] = length + + def set_mapping( + self, + 
prompt_indices: torch.Tensor, + prompt_embedding_indices: torch.Tensor, + ): + self.indices_gpu = prompt_indices.to( + device=self.emb_layer.weight.device) + self.embedding_indices_gpu = prompt_embedding_indices.to( + device=self.emb_layer.weight.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = self.base_layer(x) + if self.embedding_indices_gpu.ndim > 1: + valid_mask = self.indices_gpu != -1 + gathered_embeddings = self.embeddings_tensors[ + self.embedding_indices_gpu[:, 0], + self.embedding_indices_gpu[:, 1]] + + # Update hidden states + hidden_states[valid_mask] = gathered_embeddings + return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py new file mode 100644 index 0000000..7955916 --- /dev/null +++ b/vllm/prompt_adapter/models.py @@ -0,0 +1,357 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import math +from typing import Any, Callable, Dict, List, Optional, Type + +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.layers import ( + VocabParallelEmbeddingWithPromptAdapter) # yapf: disable +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.utils import load_peft_weights + +logger = logging.getLogger(__name__) + +_GLOBAL_PROMPT_ADAPTER_ID = 0 + + +def get_prompt_adapter_id(): + global _GLOBAL_PROMPT_ADAPTER_ID + _GLOBAL_PROMPT_ADAPTER_ID += 1 + return _GLOBAL_PROMPT_ADAPTER_ID + + +def convert_to_embedding_indices(indices): + embedding_indices = [] + count = 0 + + for value in indices: + if value == -1: + count = 0 + else: + embedding_indices.append([value, count]) + count += 1 + + return torch.tensor(embedding_indices) + + +def convert_mapping( + mapping: PromptAdapterMapping, + prompt_adapter_index_to_id: List[Optional[int]], +) -> torch.Tensor: + """Converts PromptAdapterMapping to index tensors. + + Args: + mapping: PromptAdapterMapping mapping rows in a + batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter + ids to PromptAdapter indices. + + Returns: + pa_indices: Tensor of shape [batch_size] mapping batch rows to + PromptAdapter indices. 
+ """ + id_to_index = { + id_: idx + for idx, id_ in enumerate(prompt_adapter_index_to_id) + if id_ is not None + } + pa_indices = ([ + id_to_index.get(id_, -1) if id_ > 0 else -1 + for id_ in mapping.index_mapping + ]) + + pa_embedding_mapping = convert_to_embedding_indices(pa_indices) + pa_indices = torch.tensor(pa_indices) + return pa_indices, pa_embedding_mapping + + +class PromptAdapterModel(AdapterModel): + + def __init__(self, + prompt_adapter_id=None, + num_virtual_tokens=None, + prompt_embedding=None) -> None: + self.id = prompt_adapter_id + self.prompt_embedding = prompt_embedding + self.num_virtual_tokens = num_virtual_tokens + + @classmethod + def from_local_checkpoint( + cls, + adapter_model_path: str, + prompt_adapter_id: int, + num_virtual_tokens: int, + config: PromptAdapterConfig, + device: str = "cuda", + ) -> "PromptAdapterModel": + + if num_virtual_tokens > config.max_prompt_adapter_token: + raise ValueError( + f'num_virtual_tokens ({num_virtual_tokens}) should be <= ' + f'max_prompt_adapter_token({config.max_prompt_adapter_token})') + + adapters_weights = load_peft_weights(adapter_model_path, device) + prompt_embedding = adapters_weights["prompt_embeddings"].to( + config.prompt_adapter_dtype) + + return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) + + +class PromptAdapterModelManager(AdapterModelManager): + """A manager that manages multiple Prompt Adapter models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + """Create a PromptAdapterModel and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + prompt_adapter_config: the PromptAdapter config, + """ + self.model: nn.Module = model + # Dict instead of a Set for compatibility with LRUCache. + self.prompt_adapter_index_to_id: List[ + Optional[int]] = [None] * self.prompt_adapter_slots + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.prompt_adapter_config = prompt_adapter_config + self.model.prompt_adapter_manager = self + self.adapter_type = 'PromptAdapter' + + self.base_indices = torch.tensor([-1]) + self.base_embedding_indices = torch.tensor([]) + + self.modules: Dict[str, nn.Module] = {} + self._create_prompt_adapter_modules() + self._last_mapping: Optional[PromptAdapterMapping] = None + + @property + def prompt_adapter_slots(self) -> int: + return self.prompt_adapter_config.max_prompt_adapters + + @property + def adapter_slots(self) -> int: + return self.prompt_adapter_slots + + @property + def capacity(self) -> int: + return self.prompt_adapter_config.max_cpu_prompt_adapters + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + """Move PromptAdapter into a GPU buffer + to be used in the forward pass.""" + if prompt_adapter_id in self._active_adapters: + return False + first_free_slot = next( + ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate( + self.prompt_adapter_index_to_id) if prompt_adapter_id is None), + None) + if first_free_slot is None: + raise ValueError("No free prompt_adapter slots") + index, _ = first_free_slot + self._active_adapters[prompt_adapter_id] = None + prompt_adapter_model = (self._registered_adapters[prompt_adapter_id]) + logger.debug("Activating prompt_adapter. 
int id: %d, slot index: %d", + prompt_adapter_model.id, index) + self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id + for _, v in self.modules.items(): + v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding) + return True + + def _deactivate_adapter(self, prompt_adapter_id: int): + try: + index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) + self.prompt_adapter_index_to_id[index] = None + for _, v in self.modules.items(): + v.reset_prompt_adapter(index) + except ValueError: + pass + + def _add_adapter(self, prompt_adapter: PromptAdapterModel): + self._registered_adapters[prompt_adapter.id] = prompt_adapter + + def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + base_indices, base_embedding_indices = convert_mapping( + mapping, self.prompt_adapter_index_to_id) + for k, v in self.modules.items(): + v.set_mapping(base_indices, base_embedding_indices) + + def _create_prompt_adapter_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if "VocabParallel" in module.__class__.__name__: + new_module = VocabParallelEmbeddingWithPromptAdapter(module) + new_module.create_prompt_adapter_weights( + self.prompt_adapter_config) + replaced_module = self.replace_submodule( + self.model, module_name, new_module) + self.register_module(module.__class__.__name__, + replaced_module) + replaced_module.set_mapping(self.base_indices, + self.base_embedding_indices) + break + + def replace_submodule(self, model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + def register_module(self, module_name: str, module: nn.Module): + self.modules[module_name] = module + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in PromptAdapterModelManager. 
" + "Use LRUCachePromptAdapterModelManager for pinning" + ) # type: ignore + + def remove_all_adapters(self): + """Remove all PromptAdapterModel from the manager.""" + self._registered_adapters.clear() + self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots + self._active_adapters.clear() + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: PromptAdapterModel) -> bool: + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): + + def __init__(self, capacity: int, + deactivate_prompt_adapter_fn: Callable[[int], bool]): + super().__init__(capacity, deactivate_prompt_adapter_fn) + + +class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): + """A model manager that manages multiple prompt_adapters with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + self.prompt_adapter_config = prompt_adapter_config + super().__init__(model, max_num_seqs, max_num_batched_tokens, + prompt_adapter_config) + self._registered_adapters = PromptAdapterLRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters = PromptAdapterLRUCache( + self.prompt_adapter_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, PromptAdapterModel]: + """List all registered PromptAdapterModel.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + """Add a PromptAdapterModel to the manager.""" + if prompt_adapter.id not in self._registered_adapters: + self._add_adapter(prompt_adapter) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(prompt_adapter.id) + was_added = False + return was_added + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + if prompt_adapter_id not in self._active_adapters and len( + self._active_adapters) >= self.prompt_adapter_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(prompt_adapter_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(prompt_adapter_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id) + self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id) + return True + + def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): + try: + self._registered_adapters.pin(prompt_adapter_id) + except ValueError as err: + raise ValueError( + "Pinning failed. 
" + f"Prompt Adapter {prompt_adapter_id} is not registered." + ) from err + + def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): + if prompt_adapter_id not in self._active_adapters: + # move adapter to gpu if not already active + self.activate_adapter(prompt_adapter_id) + self._active_adapters.pin(prompt_adapter_id) + + +def create_prompt_adapter_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_manager_cls: Type[ + PromptAdapterModelManager] = PromptAdapterModelManager, + **kwargs) -> PromptAdapterModelManager: + """Create a PromptAdapterModel for a given model.""" + prompt_adapter_manager = prompt_adapter_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + prompt_adapter_config=prompt_adapter_config, + **kwargs) + return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py new file mode 100644 index 0000000..dfb8e61 --- /dev/null +++ b/vllm/prompt_adapter/request.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 + +import msgspec + +from vllm.adapter_commons.request import AdapterRequest + + +class PromptAdapterRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + frozen=True): # type: ignore[call-arg] + """ + Request for a Prompt adapter. + """ + __metaclass__ = AdapterRequest + + prompt_adapter_name: str + prompt_adapter_id: int + prompt_adapter_local_path: str + prompt_adapter_num_virtual_tokens: int + + def __hash__(self): + return super().__hash__() + + @property + def adapter_id(self): + return self.prompt_adapter_id + + @property + def name(self): + return self.prompt_adapter_name + + @property + def local_path(self): + return self.prompt_adapter_local_path diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py new file mode 100644 index 0000000..dd179ab --- /dev/null +++ b/vllm/prompt_adapter/utils.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 + +# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420 + +import os +from typing import Optional + +import torch +from huggingface_hub import file_exists, hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from safetensors.torch import load_file as safe_load_file + +from vllm.platforms import current_platform + +WEIGHTS_NAME = "adapter_model.bin" +SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" + + +# Get current device name based on available devices +def infer_device() -> str: + if current_platform.is_cuda_alike(): + return "cuda" + return "cpu" + + +def load_peft_weights(model_id: str, + device: Optional[str] = None, + **hf_hub_download_kwargs) -> dict: + r""" + A helper method to load the PEFT weights from the HuggingFace Hub or locally + + Args: + model_id (`str`): + The local path to the adapter weights or the name of the adapter to + load from the HuggingFace Hub. + device (`str`): + The device to load the weights onto. + hf_hub_download_kwargs (`dict`): + Additional arguments to pass to the `hf_hub_download` method when + loading from the HuggingFace Hub. 
+ """ + path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) if + hf_hub_download_kwargs.get("subfolder") is not None else model_id) + + if device is None: + device = infer_device() + + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + use_safetensors = True + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + use_safetensors = False + else: + token = hf_hub_download_kwargs.get("token") + if token is None: + token = hf_hub_download_kwargs.get("use_auth_token") + + hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"], + SAFETENSORS_WEIGHTS_NAME) + if hf_hub_download_kwargs.get("subfolder") is not None + else SAFETENSORS_WEIGHTS_NAME) + has_remote_safetensors_file = file_exists( + repo_id=model_id, + filename=hub_filename, + revision=hf_hub_download_kwargs.get("revision"), + repo_type=hf_hub_download_kwargs.get("repo_type"), + token=token, + ) + use_safetensors = has_remote_safetensors_file + + if has_remote_safetensors_file: + # Priority 1: load safetensors weights + filename = hf_hub_download( + model_id, + SAFETENSORS_WEIGHTS_NAME, + **hf_hub_download_kwargs, + ) + else: + try: + filename = hf_hub_download(model_id, WEIGHTS_NAME, + **hf_hub_download_kwargs) + except EntryNotFoundError: + raise ValueError( # noqa: B904 + f"Can't find weights for {model_id} in {model_id} or \ + in the Hugging Face Hub. " + f"Please check that the file {WEIGHTS_NAME} or \ + {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.") + + if use_safetensors: + adapters_weights = safe_load_file(filename, device=device) + else: + adapters_weights = torch.load(filename, + map_location=torch.device(device), + weights_only=True) + + return adapters_weights diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py new file mode 100644 index 0000000..28dcc16 --- /dev/null +++ b/vllm/prompt_adapter/worker_manager.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any, Optional, Set, Type + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, + PromptAdapterModel, + PromptAdapterModelManager, + create_prompt_adapter_manager) +from vllm.prompt_adapter.request import PromptAdapterRequest + +logger = logging.getLogger(__name__) + + +class WorkerPromptAdapterManager(AbstractWorkerManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. 
+ + Every request, the requested prompt_adapters will be + loaded (unless they are already loaded), + and every other prompt_adapter will be unloaded.""" + + _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + device: torch.device, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel + ): + self._adapter_manager: PromptAdapterModelManager + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self._prompt_adapter_model_cls = prompt_adapter_model_cls + self.prompt_adapter_config = prompt_adapter_config + super().__init__(device) + + @property + def is_enabled(self) -> bool: + return True + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._manager_cls, + ) + self._adapter_manager = prompt_adapter_manager + return prompt_adapter_manager.model + + def _load_adapter( + self, prompt_adapter_request: PromptAdapterRequest + ) -> PromptAdapterModel: + try: + prompt_adapter = ( + self._prompt_adapter_model_cls.from_local_checkpoint( + prompt_adapter_request.prompt_adapter_local_path, + prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, + num_virtual_tokens=prompt_adapter_request. + prompt_adapter_num_virtual_tokens, + config=self.prompt_adapter_config, + device=str(self.device), + )) + except Exception as e: + raise RuntimeError( + f"Loading prompt_adapter " + f"{prompt_adapter_request.prompt_adapter_local_path}" + f" failed") from e + return prompt_adapter + + def add_dummy_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return True + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. + + Uses an LRU Cache. 
Every request, the requested + prompt_adapters will be loaded (unless they are already loaded) + and least recently used prompt_adapters will + be unloaded if the cache is above capacity.""" + + _prompt_adapter_manager_cls: Type[ + LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) + self._adapter_manager: LRUCachePromptAdapterModelManager = ( + prompt_adapter_manager) + return prompt_adapter_manager.model + + def _apply_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: + prompt_adapters_map = { + prompt_adapter_request.prompt_adapter_id: prompt_adapter_request + for prompt_adapter_request in prompt_adapter_requests + if prompt_adapter_request + } + if len(prompt_adapters_map + ) > self._adapter_manager.prompt_adapter_slots: + raise RuntimeError( + f"Number of requested prompt_adapters " + f"({len(prompt_adapters_map)}) is greater " + "than the number of GPU prompt_adapter slots " + f"({self._adapter_manager.prompt_adapter_slots}).") + for prompt_adapter in prompt_adapters_map.values(): + self.add_adapter(prompt_adapter) + + def add_adapter(self, + prompt_adapter_request: PromptAdapterRequest) -> bool: + if prompt_adapter_request.prompt_adapter_id not in self.list_adapters( + ): + # Remove before we load the new prompt_adapter to save memory + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + self._adapter_manager.remove_oldest_adapter() + prompt_adapter = self._load_adapter(prompt_adapter_request) + loaded = self._adapter_manager.add_adapter(prompt_adapter) + else: + # If the prompt_adapter is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + prompt_adapter_request.prompt_adapter_id) is not None + self._adapter_manager.activate_adapter( + prompt_adapter_request.prompt_adapter_id) + return loaded diff --git a/vllm/py.typed b/vllm/py.typed new file mode 100644 index 0000000..33b3ad7 --- /dev/null +++ b/vllm/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The vllm package uses inline types. 
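
For orientation, a minimal sketch of constructing the prompt-adapter request object defined above; the adapter name, id, and local path are assumed example values, and the directory is expected to hold the PEFT weights (`adapter_model.safetensors` or `adapter_model.bin`) with a `prompt_embeddings` tensor, as `load_peft_weights` expects:

```python
from vllm.prompt_adapter.request import PromptAdapterRequest

# Hypothetical adapter metadata; the local path must point at a PEFT
# prompt-tuning checkpoint containing a "prompt_embeddings" tensor.
request = PromptAdapterRequest(
    prompt_adapter_name="demo-prompt-tuning",
    prompt_adapter_id=1,
    prompt_adapter_local_path="/mnt/adapters/demo-prompt-tuning",
    prompt_adapter_num_virtual_tokens=8,
)

# Convenience properties defined on the request:
print(request.adapter_id)   # 1
print(request.name)         # demo-prompt-tuning
print(request.local_path)   # /mnt/adapters/demo-prompt-tuning
```

At serving time, the worker manager loads such a request through `PromptAdapterModel.from_local_checkpoint` and activates it in a free adapter slot.
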
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py new file mode 100644 index 0000000..65606ce --- /dev/null +++ b/vllm/reasoning/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager +from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from .granite_reasoning_parser import GraniteReasoningParser +from .qwen3_reasoning_parser import Qwen3ReasoningParser + +__all__ = [ + "ReasoningParser", + "ReasoningParserManager", + "DeepSeekR1ReasoningParser", + "GraniteReasoningParser", + "Qwen3ReasoningParser", +] diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py new file mode 100644 index 0000000..454167a --- /dev/null +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from abc import abstractmethod +from collections.abc import Sequence +from functools import cached_property +from typing import Callable, Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import import_from_path, is_list_of + +logger = init_logger(__name__) + + +class ReasoningParser: + """ + Abstract reasoning parser class that should not be used directly. + Provided and methods should be used in derived classes. + + It is used to extract reasoning content from the model output. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + @abstractmethod + def is_reasoning_end(self, input_ids: list[int]) -> bool: + """ + Check if the reasoning content ends in the input_ids. + + It is used in structured engines like `xgrammar` to check if the + reasoning content ends in the model output. + + Parameters: + input_ids: list[int] + The input_ids of the model output. + + Returns: + bool + True if the reasoning content ends in the input_ids. + """ + + @abstractmethod + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract content token ids from the input_ids. + Parameters: + input_ids: list[int] + The input_ids of the model output. + Returns: + list[int] + The extracted content from the input_ids. + """ + + @abstractmethod + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + """ + Extract reasoning content from a complete model-generated string. + + Used for non-streaming responses where we have the entire model response + available before sending to the client. + + Parameters: + model_output: str + The model-generated string to extract reasoning content from. + + request: ChatCompletionRequest + The request object that was used to generate the model_output. + + Returns: + tuple[Optional[str], Optional[str]] + A tuple containing the reasoning content and the content. 
+ """ + + @abstractmethod + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting reasoning + from an incomplete response; for use when handling reasoning calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + + +class ReasoningParserManager: + reasoning_parsers: dict[str, type] = {} + + @classmethod + def get_reasoning_parser(cls, name) -> type: + """ + Get reasoning parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. + """ + if name in cls.reasoning_parsers: + return cls.reasoning_parsers[name] + + raise KeyError( + f"reasoning helper: '{name}' not found in reasoning_parsers") + + @classmethod + def _register_module( + cls, + module: type, + module_name: Optional[Union[str, list[str]]] = None, + force: bool = True, + ) -> None: + if not issubclass(module, ReasoningParser): + raise TypeError("module must be subclass of ReasoningParser, " + f"but got {type(module)}") + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.reasoning_parsers: + existed_module = cls.reasoning_parsers[name] + raise KeyError(f"{name} is already registered " + f"at {existed_module.__module__}") + cls.reasoning_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, list[str]]] = None, + force: bool = True, + module: Union[type, None] = None, + ) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + "name must be None, an instance of str, or a sequence of str, " + f"but got {type(name)}") + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_reasoning_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined reasoning parser by the path + of the reasoning parser define file. 
+        """
+        module_name = os.path.splitext(os.path.basename(plugin_path))[0]
+
+        try:
+            import_from_path(module_name, plugin_path)
+        except Exception:
+            logger.exception("Failed to load module '%s' from %s.",
+                             module_name, plugin_path)
+            return
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
new file mode 100644
index 0000000..1c283c0
--- /dev/null
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("deepseek_r1")
+class DeepSeekR1ReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for DeepSeek R1 model.
+
+    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
+    text. This parser extracts the reasoning content from the model output.
+    """
+
+    start_token_id: int
+    end_token_id: int
+
+    start_token: str = "<think>"
+    end_token: str = "</think>"
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser "
+                "constructor during construction.")
+
+        self.start_token_id = self.vocab.get(self.start_token)
+        self.end_token_id = self.vocab.get(self.end_token)
+        if self.start_token_id is None or self.end_token_id is None:
+            raise RuntimeError(
+                "DeepSeek R1 reasoning parser could not locate think start/end "
+                "tokens in the tokenizer!")
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.end_token_id in input_ids
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract the content after the end tokens
+        """
+        if self.end_token_id not in input_ids[:-1]:
+            return []
+        else:
+            return input_ids[input_ids.index(self.end_token_id) + 1:]
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+        """
+        Extract reasoning content from a delta message.
+        Handles streaming output where previous + delta = current.
+        Uses token IDs for faster processing.
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+        """
+        # Skip single special tokens
+        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+                self.start_token_id, self.end_token_id
+        ]):
+            return None
+
+        # Check if <think> is present in previous or delta.
+        # Keep compatibility with models that don't generate <think> tokens.
+        if self.start_token_id in previous_token_ids:
+            if self.end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+                # extract reasoning content
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+                # reasoning content continues
+                return DeltaMessage(content=delta_text)
+            else:
+                # <think> in previous, no </think> in previous or delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        elif self.start_token_id in delta_token_ids:
+            if self.end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta, extract reasoning content
+                start_index = delta_text.find(self.start_token)
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[start_index +
+                                               len(self.start_token):end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            else:
+                # <think> in delta, no </think> in delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        else:
+            # No <think> in previous or delta, also need to check for </think>.
+            # Because the model may have generated </think> without <think>
+            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
+            if self.end_token_id in delta_token_ids:
+                # </think> in delta with more tokens,
+                # extract reasoning content and content
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
+                # </think> in previous, thinking content ends
+                return DeltaMessage(content=delta_text)
+            else:
+                # no </think> in previous or delta, reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+
+    def extract_reasoning_content(
+            self, model_output: str, request: ChatCompletionRequest
+    ) -> tuple[Optional[str], Optional[str]]:
+        """
+        Extract reasoning content from the model output.
+
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+
+        Returns:
+            tuple[Optional[str], Optional[str]]: reasoning content and content
+        """
+
+        # Check if the start token is present in the model output, remove it
+        # if it is present.
+        model_output_parts = model_output.partition(self.start_token)
+        model_output = model_output_parts[2] if model_output_parts[
+            1] else model_output_parts[0]
+
+        # DeepSeek R1 doesn't generate <think> now.
+        # Thus we assume the reasoning content is always at the start.
+        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
+        if self.end_token not in model_output:
+            return model_output, None
+        else:
+            reasoning_content, _, content = model_output.partition(
+                self.end_token)
+            # If the end token is not found, return the model output as is.
+            # It should not happen since we already checked for the presence
+            # of the end token.
+ # If generation stops right after end-of-think, return null content + final_content = content or None + return reasoning_content, final_content diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py new file mode 100644 index 0000000..249ace1 --- /dev/null +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from collections.abc import Sequence +from typing import Optional, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("granite") +class GraniteReasoningParser(ReasoningParser): + """ + Reasoning parser for IBM Granite. + + IBM granite models currently use "Here is my thought process:" + and "Here is my response:" to separate its thinking / response outputs. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # NOTE: There have been some observed occurrences of quantized + # instances of the current models using "Here's" instead of "Here is", + # so to be safe, we match on both. + self.think_start_expr = r"(?:Here's|Here is) my thought process:" + self.response_start_expr = r"(?:Here's|Here is) my response:" + + self.reasoning_regex = re.compile( + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", + re.DOTALL) + + self.valid_think_starts = [ + "Here's my thought process:", "Here is my thought process:" + ] + self.valid_response_starts = [ + "Here's my response:", "Here is my response:" + ] + + # Substrings to match for sequence boundaries on raw text + self.seq_boundary_end = ":" + self.seq_boundary_start = "Here" + + # The longest any thinking / start of response message can be + self.longest_think_start = max( + len(think_start) for think_start in self.valid_think_starts) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. + request (ChatCompletionReqest): Request being processed. + + Returns: + tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. + """ + re_match = self.reasoning_regex.findall(model_output) + if not re_match: + return None, model_output + reasoning_content, response_content = re_match[0] + if not response_content: + return reasoning_content, None + return reasoning_content, response_content + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """Extract the reasoning content / content emitted by granite models; + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. 
+ + NOTE: Granite models do not use a special token to start their reasoning + and response sections; instead they have token sequences, e.g., + + Here is my thought process: Foo Here is my response: Bar + + This increases the complexity of correctly handling streams, since we + need to watch for specific sequences and correctly parse them without + dropping content that is potentially overlapping & spanning multiple + delta messages. + + Args: + previous_text (str): Previous text outside of this delta message. + current_text (str): Previous text + delta text. + delta_text (str): Text to consider and parse content from. + previous_token_ids (Sequence[int]): Token IDs of previous_text. + current_token_ids (Sequence[int]): Token IDs of current_text. + delta_token_ids (Sequence[int]): Token IDs of delta_text. + + Returns: + Union[DeltaMessage, None] + DeltaMessage with either reasoning content or content, or None. + """ + reasoning_content, resp_seq_len, content = self._get_content_sections( + current_text) + # Either we haven't finished the start of the reasoning sequence, + # or the model is generating something unexpected. + if not reasoning_content: + delta_message = self._get_delta_message_with_no_reasoning_bounds( + current_text, delta_text) + # We have a start of reasoning message, but have not yet finished + # the start of response sequence. + elif not content: + delta_message = self._get_delta_message_with_no_response_bounds( + current_text, reasoning_content, delta_text) + # We've finished both the start of reasoning and start of response seq. + else: + # This should never happen since we matched on the response + assert resp_seq_len is not None + delta_message = self._get_delta_message_with_both_bounds( + delta_text, reasoning_content, content, current_text, + resp_seq_len) + if not delta_message.content and not delta_message.reasoning_content: + return None + return delta_message + + #### Implementation details of stream parsing for granite models + def _is_reasoning_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start reasoning seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible reasoning start seqs match. + """ + return any( + think_start.startswith(text) + for think_start in self.valid_think_starts) + + def _is_response_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start response seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible response start seqs match. + """ + return any( + response_start.startswith(text) + for response_start in self.valid_response_starts) + + def _get_delta_message_with_no_reasoning_bounds( + self, + current_text: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has not yet completed + its start of reasoning sequence. + + Args: + current_text (str): The full previous + delta text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + prev_longest_length = len(current_text) - len(delta_text) + is_substr = self._is_reasoning_start_substr(current_text) + was_substr = self._is_reasoning_start_substr( + current_text[:prev_longest_length]) + + # Check if we just generated something NOT in the special token seq; + # if so, add everything that we previously skipped with this delta + # message and append everything to content in the future. + if was_substr and not is_substr: + return DeltaMessage( + reasoning_content=None, + content=current_text, + ) + if is_substr: + # Might still be in the special token sequence; return nothing + return DeltaMessage(reasoning_content=None, content=None) + # Otherwise the sequence has already been broken and we already + # corrected; just return the delta text as normal content. + return DeltaMessage(reasoning_content=None, content=delta_text) + + def _get_delta_message_with_no_response_bounds( + self, + current_text: str, + reasoning_content: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content with no (response) content. NOTE that we may have overlapping + tokens with the start of reasoning / start of response sequences on + either side of the delta text. + + Args: + current_text (str): The full previous + delta text. + reasoning_content (str): reasoning content from current_text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. + """ + # If we have no reasoning content or explicitly end with the start of + # response sequence, we are in transition to the response; need to be + # careful here, since the final token (:) will match the reasoning + # content and fully parse it out; we should not pass the : back. + ends_with_start_response_seq = any( + current_text.endswith(response_start) + for response_start in self.valid_response_starts) + if reasoning_content is None or ends_with_start_response_seq: + return DeltaMessage(reasoning_content=None, content=None) + + # Consider previous / current text only within context of the reasoning + previous_text = reasoning_content[:-len(delta_text)] + current_text = reasoning_content + + # We need to be careful about adding unfinished response sequences; + # Find the place at which we MIGHT be starting a response sequence + prev_idx = previous_text.rfind(self.seq_boundary_start) + delta_idx = delta_text.rfind(self.seq_boundary_start) + + # Check the state of potential start of response substring matches. + prev_was_substr = self._is_response_start_substr( + previous_text[prev_idx:]) if prev_idx >= 0 else False + delta_continues_substr = self._is_response_start_substr( + current_text[prev_idx:]) if prev_idx >= 0 else False + delta_new_substr = self._is_response_start_substr( + delta_text[delta_idx:]) if delta_idx >= 0 else False + + # Delta only contains potential continued response sequence text. + if delta_continues_substr: + return DeltaMessage(reasoning_content=None, content=None) + + if not prev_was_substr: + # Delta may be starting a new response seq but has other text too. + if delta_new_substr: + return DeltaMessage(reasoning_content=delta_text[:delta_idx], + content=None) + # Normal case for most reasoning text (no potential special seqs). 
+ return DeltaMessage(reasoning_content=delta_text, content=None) + # The substring that previously seemed to be a potential response + # seq wasn't one; we need to add the content to the delta message, + # and also slice off the potential response sequence + elif delta_new_substr: + reasoning_content = previous_text[ + prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning_content=reasoning_content, + content=None) + # No new substring yet, and we broke our old one; take the whole delta + return DeltaMessage( + reasoning_content=previous_text[prev_idx:] + delta_text, + content=None, + ) + + def _get_delta_message_with_both_bounds( + self, + delta_text: str, + reasoning_content: str, + response_content: str, + current_text: str, + response_seq_len: int, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content and normal (response) content. + + Args: + delta_text (str): Text to consider and parse content from. + reasoning_content (str): reasoning content from current_text. + response_content (str): response content from current_text. + current_text (str): The full previous + delta text. + response_seq_len(str): Len of the complete response sequence used. + + Returns: + DeltaMessage: Message containing the parsed content. + """ + # Always have content; take length to the end + delta_content = delta_text[-len(response_content):] + reasoning_end_idx = len(delta_text) - (len(response_content) + + response_seq_len) + + if reasoning_end_idx < 0: + delta_reasoning_content = None + else: + # Get the starting offset + start_reasoning_content_idx = len( + reasoning_content) + response_seq_len + len( + response_content) - 1 + delta_offset = len(current_text) - len(delta_text) + start_offset = start_reasoning_content_idx - delta_offset + if start_offset < 0: + start_offset = 0 + delta_reasoning_content = delta_text[ + start_offset:reasoning_end_idx] + + return DeltaMessage( + reasoning_content=delta_reasoning_content, + content=delta_content, + ) + + def _get_content_sections( + self, current_text: str + ) -> tuple[Optional[str], Optional[int], Optional[str]]: + """Parse the text to extract the reasoning content / content + if we have them. + + Args: + current_text (str): The full previous + delta text. + + Returns: + tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3 + containing the reasoning content, the length of the response seq + (if there is one) and the non-reasoning content. + """ + current_chunk_start = 0 + start_reasoning_content = None + parsed_content = False + delimiter_idxs = [ + idx for idx, char in enumerate(current_text) + if char == self.seq_boundary_end + ] + + for current_chunk_end in delimiter_idxs: + current_chunk = current_text[current_chunk_start:current_chunk_end] + # Check to see if the start of reasoning seq if complete + if start_reasoning_content is None: + for think_start in self.valid_think_starts: + if current_chunk == think_start[:-1]: + start_reasoning_content = current_chunk_end + 1 + current_chunk_start = current_chunk_end + 1 + break + + # Check to see if the start of response seq if complete + elif not parsed_content: + for response_start in self.valid_response_starts: + if current_chunk[-len(response_start) + + 1:] == response_start[:-1]: + # Mark end of reasoning and start response content + # after the start of response sequence. 
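+                        # Illustrative example: in
+                        # "Here is my thought process: Foo Here is my response: Bar",
+                        # the chunk ending at the second ':' completes
+                        # "Here is my response:", so the text between the two
+                        # sequences ("Foo", modulo surrounding whitespace)
+                        # becomes the reasoning content and everything after
+                        # the response sequence ("Bar") becomes the response.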
+                        end_reasoning_content = current_chunk_end - len(
+                            response_start)
+                        reasoning_content = current_text[
+                            start_reasoning_content:end_reasoning_content]
+                        response_content = current_text[current_chunk_end + 1:]
+                        return reasoning_content, len(
+                            response_start), response_content
+
+        if start_reasoning_content and not parsed_content:
+            return current_text[start_reasoning_content:], None, None
+        return None, None, None
diff --git a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py
new file mode 100644
index 0000000..7095034
--- /dev/null
+++ b/vllm/reasoning/qwen3_reasoning_parser.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("qwen3")
+class Qwen3ReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for the Qwen3 model.
+
+    The Qwen3 model uses <think>...</think> tokens to denote reasoning text
+    within its output. The model provides a strict switch to disable reasoning
+    output via the 'enable_thinking=False' parameter. This parser extracts the
+    reasoning content enclosed by <think> and </think> tokens from the model's
+    output.
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+        self.think_start_token = "<think>"
+        self.think_end_token = "</think>"
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser "
+                "constructor during construction.")
+
+        self.think_start_token_id = self.vocab.get(self.think_start_token)
+        self.think_end_token_id = self.vocab.get(self.think_end_token)
+        if (self.think_start_token_id is None
+                or self.think_end_token_id is None):
+            raise RuntimeError(
+                "Qwen3 reasoning parser could not locate think start/end "
+                "tokens in the tokenizer!")
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.think_end_token_id in input_ids
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract the content after the end tokens
+        """
+        if self.think_end_token_id not in input_ids[:-1]:
+            return []
+        else:
+            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+        """
+        Extract reasoning content from a delta message.
+        Handles streaming output where previous + delta = current.
+        Uses token IDs for faster processing.
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+        """
+        # Skip single special tokens
+        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+                self.think_start_token_id, self.think_end_token_id
+        ]):
+            return None
+
+        if self.think_start_token_id in previous_token_ids:
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+                # extract reasoning content
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            elif self.think_end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+                # response content continues
+                return DeltaMessage(content=delta_text)
+            else:
+                # <think> in previous, no </think> in previous or delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        elif self.think_start_token_id in delta_token_ids:
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta,
+                # extract reasoning content
+                start_index = delta_text.find(self.think_start_token)
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[start_index +
+                                               len(self.think_start_token
+                                                   ):end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            else:
+                # <think> in delta, no </think> in delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        else:
+            # thinking is disabled, just content
+            return DeltaMessage(content=delta_text)
+
+    def extract_reasoning_content(
+            self, model_output: str, request: ChatCompletionRequest
+    ) -> tuple[Optional[str], Optional[str]]:
+        """
+        Extract reasoning content from the model output.
+
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+
+        Returns:
+            tuple[Optional[str], Optional[str]]: reasoning content and content
+        """
+
+        # Check if the model output contains the <think> and </think> tokens.
+        if (self.think_start_token not in model_output
+                or self.think_end_token not in model_output):
+            return None, model_output
+        # Check if the <think> is present in the model output, remove it
+        # if it is present.
+        model_output_parts = model_output.partition(self.think_start_token)
+        model_output = model_output_parts[2] if model_output_parts[
+            1] else model_output_parts[0]
+        # Check if the model output contains the </think> token.
+        # If the end token is not found, return the model output as is.
+        if self.think_end_token not in model_output:
+            return None, model_output
+
+        # Extract reasoning content from the model output.
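+        # At this point any leading <think> has already been stripped, so for
+        # example "abc</think>xyz".partition("</think>") yields
+        # ("abc", "</think>", "xyz"): the first element is the reasoning text
+        # and the last element is the response content.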
+ reasoning_content, _, content = model_output.partition( + self.think_end_token) + + final_content = content or None + return reasoning_content, final_content diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py new file mode 100644 index 0000000..584320e --- /dev/null +++ b/vllm/sampling_params.py @@ -0,0 +1,579 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Sampling parameters for text generation.""" +import copy +from dataclasses import dataclass +from enum import Enum, IntEnum +from functools import cached_property +from typing import Annotated, Any, Optional, Union + +import msgspec +from pydantic import BaseModel + +from vllm.logger import init_logger +from vllm.logits_process import LogitsProcessor +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + +logger = init_logger(__name__) + +_SAMPLING_EPS = 1e-5 +_MAX_TEMP = 1e-2 + + +class SamplingType(IntEnum): + GREEDY = 0 + RANDOM = 1 + RANDOM_SEED = 2 + + +# maybe make msgspec? +@dataclass +class GuidedDecodingParams: + """One of these fields will be used to build a logit processor.""" + json: Optional[Union[str, dict]] = None + regex: Optional[str] = None + choice: Optional[list[str]] = None + grammar: Optional[str] = None + json_object: Optional[bool] = None + """These are other options that can be set""" + backend: Optional[str] = None + whitespace_pattern: Optional[str] = None + + @staticmethod + def from_optional( + json: Optional[Union[dict, BaseModel, str]] = None, + regex: Optional[str] = None, + choice: Optional[list[str]] = None, + grammar: Optional[str] = None, + json_object: Optional[bool] = None, + backend: Optional[str] = None, + whitespace_pattern: Optional[str] = None, + ) -> Optional["GuidedDecodingParams"]: + if all(arg is None + for arg in (json, regex, choice, grammar, json_object)): + return None + # Extract json schemas from pydantic models + if isinstance(json, (BaseModel, type(BaseModel))): + json = json.model_json_schema() + return GuidedDecodingParams( + json=json, + regex=regex, + choice=choice, + grammar=grammar, + json_object=json_object, + backend=backend, + whitespace_pattern=whitespace_pattern, + ) + + @property + def backend_name(self) -> str: + """Return the backend name without any options. 
+ + For example if the backend is "xgrammar:no-fallback", returns "xgrammar" + """ + return (self.backend or "").split(":")[0] + + def backend_options(self) -> list[str]: + """Return the backend options as a list of strings.""" + if not self.backend or ":" not in self.backend: + return [] + return self.backend.split(":")[1].split(",") + + def no_fallback(self) -> bool: + """Returns True if the "no-fallback" option is supplied for the guided + decoding backend""" + return "no-fallback" in self.backend_options() + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.json is not None, self.regex is not None, self.choice + is not None, self.grammar is not None, self.json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") + + +class RequestOutputKind(Enum): + # Return entire output so far in every RequestOutput + CUMULATIVE = 0 + # Return only deltas in each RequestOutput + DELTA = 1 + # Do not return intermediate RequestOuputs + FINAL_ONLY = 2 + + +class SamplingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): # type: ignore[call-arg] + """Sampling parameters for text generation. + + Overall, we follow the sampling parameters from the OpenAI text completion + API (https://platform.openai.com/docs/api-reference/completions/create). + In addition, we support beam search, which is not supported by OpenAI. + + Args: + n: Number of output sequences to return for the given prompt. + best_of: Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. By default, + `best_of` is set to `n`. Warning, this is only supported in V0. + presence_penalty: Float that penalizes new tokens based on whether they + appear in the generated text so far. Values > 0 encourage the model + to use new tokens, while values < 0 encourage the model to repeat + tokens. + frequency_penalty: Float that penalizes new tokens based on their + frequency in the generated text so far. Values > 0 encourage the + model to use new tokens, while values < 0 encourage the model to + repeat tokens. + repetition_penalty: Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens. + temperature: Float that controls the randomness of the sampling. Lower + values make the model more deterministic, while higher values make + the model more random. Zero means greedy sampling. + top_p: Float that controls the cumulative probability of the top tokens + to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + top_k: Integer that controls the number of top tokens to consider. Set + to -1 to consider all tokens. + min_p: Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this. + seed: Random seed to use for the generation. + stop: list of strings that stop the generation when they are generated. + The returned output will not contain the stop strings. + stop_token_ids: list of tokens that stop the generation when they are + generated. 
The returned output will contain the stop tokens unless + the stop tokens are special tokens. + bad_words: list of words that are not allowed to be generated. + More precisely, only the last token of a corresponding + token sequence is not allowed when the next generated token + can complete the sequence. + include_stop_str_in_output: Whether to include the stop strings in + output text. Defaults to False. + ignore_eos: Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated. + max_tokens: Maximum number of tokens to generate per output sequence. + min_tokens: Minimum number of tokens to generate per output sequence + before EOS or stop_token_ids can be generated + logprobs: Number of log probabilities to return per output token. + When set to None, no probability is returned. If set to a non-None + value, the result includes the log probabilities of the specified + number of most likely tokens, as well as the chosen tokens. + Note that the implementation follows the OpenAI API: The API will + always return the log probability of the sampled token, so there + may be up to `logprobs+1` elements in the response. + prompt_logprobs: Number of log probabilities to return per prompt token. + detokenize: Whether to detokenize the output. Defaults to True. + skip_special_tokens: Whether to skip special tokens in the output. + spaces_between_special_tokens: Whether to add spaces between special + tokens in the output. Defaults to True. + logits_processors: list of functions that modify logits based on + previously generated tokens, and optionally prompt tokens as + a first argument. + truncate_prompt_tokens: If set to an integer k, will use only the last k + tokens from the prompt (i.e., left truncation). Defaults to None + (i.e., no truncation). + guided_decoding: If provided, the engine will construct a guided + decoding logits processor from these parameters. Defaults to None. + logit_bias: If provided, the engine will construct a logits processor + that applies these logit biases. Defaults to None. + allowed_token_ids: If provided, the engine will construct a logits + processor which only retains scores for the given token ids. + Defaults to None. + extra_args: Arbitrary additional args, that can be used by custom + sampling implementations. Not used by any in-tree sampling + implementations. + """ + + n: int = 1 + best_of: Optional[int] = None + _real_n: Optional[int] = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + seed: Optional[int] = None + stop: Optional[Union[str, list[str]]] = None + stop_token_ids: Optional[list[int]] = None + ignore_eos: bool = False + max_tokens: Optional[int] = 16 + min_tokens: int = 0 + logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = None + # NOTE: This parameter is only exposed at the engine level for now. + # It is not exposed in the OpenAI API server, as the OpenAI API does + # not support returning only a list of token IDs. + detokenize: bool = True + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + # Optional[list[LogitsProcessor]] type. We use Any here because + # Optional[list[LogitsProcessor]] type is not supported by msgspec. 
+ logits_processors: Optional[Any] = None + include_stop_str_in_output: bool = False + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE + + # The below fields are not supposed to be used as an input. + # They are set in post_init. + output_text_buffer_length: int = 0 + _all_stop_token_ids: set[int] = msgspec.field(default_factory=set) + + # Fields used to construct logits processors + guided_decoding: Optional[GuidedDecodingParams] = None + logit_bias: Optional[dict[int, float]] = None + allowed_token_ids: Optional[list[int]] = None + extra_args: Optional[dict[str, Any]] = None + + # Fields used for bad words + bad_words: Optional[list[str]] = None + _bad_words_token_ids: Optional[list[list[int]]] = None + + @staticmethod + def from_optional( + n: Optional[int] = 1, + best_of: Optional[int] = None, + presence_penalty: Optional[float] = 0.0, + frequency_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, + temperature: Optional[float] = 1.0, + top_p: Optional[float] = 1.0, + top_k: int = -1, + min_p: float = 0.0, + seed: Optional[int] = None, + stop: Optional[Union[str, list[str]]] = None, + stop_token_ids: Optional[list[int]] = None, + bad_words: Optional[list[str]] = None, + include_stop_str_in_output: bool = False, + ignore_eos: bool = False, + max_tokens: Optional[int] = 16, + min_tokens: int = 0, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + detokenize: bool = True, + skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, + logits_processors: Optional[list[LogitsProcessor]] = None, + truncate_prompt_tokens: Optional[Annotated[int, + msgspec.Meta(ge=1)]] = None, + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, + guided_decoding: Optional[GuidedDecodingParams] = None, + logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, + allowed_token_ids: Optional[list[int]] = None, + extra_args: Optional[dict[str, Any]] = None, + ) -> "SamplingParams": + if logit_bias is not None: + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec + logit_bias = { + int(token): min(100.0, max(-100.0, bias)) + for token, bias in logit_bias.items() + } + + return SamplingParams( + n=1 if n is None else n, + best_of=best_of, + presence_penalty=0.0 + if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 + if frequency_penalty is None else frequency_penalty, + repetition_penalty=1.0 + if repetition_penalty is None else repetition_penalty, + temperature=1.0 if temperature is None else temperature, + top_p=1.0 if top_p is None else top_p, + top_k=top_k, + min_p=min_p, + seed=seed, + stop=stop, + stop_token_ids=stop_token_ids, + bad_words=bad_words, + include_stop_str_in_output=include_stop_str_in_output, + ignore_eos=ignore_eos, + max_tokens=max_tokens, + min_tokens=min_tokens, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + detokenize=detokenize, + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors, + truncate_prompt_tokens=truncate_prompt_tokens, + output_kind=output_kind, + guided_decoding=guided_decoding, + logit_bias=logit_bias, + allowed_token_ids=allowed_token_ids, + extra_args=extra_args, + ) + + def __post_init__(self) -> None: + # how we deal with `best_of``: + # if `best_of`` is not set, we default to `n`; + # if `best_of`` is set, we set `n`` to 
`best_of`, + # and set `_real_n`` to the original `n`. + # when we return the result, we will check + # if we need to return `n` or `_real_n` results + if self.best_of: + if self.best_of < self.n: + raise ValueError( + f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}.") + if not self._real_n: + self._real_n = self.n + self.n = self.best_of + + if 0 < self.temperature < _MAX_TEMP: + logger.warning( + "temperature %s is less than %s, which may cause numerical " + "errors nan or inf in tensors. We have maxed it out to %s.", + self.temperature, _MAX_TEMP, _MAX_TEMP) + self.temperature = max(self.temperature, _MAX_TEMP) + + if self.seed == -1: + self.seed = None + + if self.stop is None: + self.stop = [] + elif isinstance(self.stop, str): + self.stop = [self.stop] + + if self.stop_token_ids is None: + self.stop_token_ids = [] + + if self.bad_words is None: + self.bad_words = [] + + if self.logprobs is True: + self.logprobs = 1 + + if self.prompt_logprobs is True: + self.prompt_logprobs = 1 + + # Number of characters to hold back for stop string evaluation + # until sequence is finished. + if self.stop and not self.include_stop_str_in_output: + self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 + + self._verify_args() + + if self.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. + self.top_p = 1.0 + self.top_k = -1 + self.min_p = 0.0 + self._verify_greedy_sampling() + + # eos_token_id is added to this by the engine + self._all_stop_token_ids.update(self.stop_token_ids) + + def _verify_args(self) -> None: + if not isinstance(self.n, int): + raise ValueError(f"n must be an int, but is of " + f"type {type(self.n)}") + if self.n < 1: + raise ValueError(f"n must be at least 1, got {self.n}.") + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError("presence_penalty must be in [-2, 2], got " + f"{self.presence_penalty}.") + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError("frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}.") + if not 0.0 < self.repetition_penalty <= 2.0: + raise ValueError("repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}.") + if self.temperature < 0.0: + raise ValueError( + f"temperature must be non-negative, got {self.temperature}.") + if not 0.0 < self.top_p <= 1.0: + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, " + f"got {self.top_k}.") + if not isinstance(self.top_k, int): + raise TypeError( + f"top_k must be an integer, got {type(self.top_k).__name__}") + if not 0.0 <= self.min_p <= 1.0: + raise ValueError("min_p must be in [0, 1], got " + f"{self.min_p}.") + if self.max_tokens is not None and self.max_tokens < 1: + raise ValueError( + f"max_tokens must be at least 1, got {self.max_tokens}.") + if self.min_tokens < 0: + raise ValueError(f"min_tokens must be greater than or equal to 0, " + f"got {self.min_tokens}.") + if self.max_tokens is not None and self.min_tokens > self.max_tokens: + raise ValueError( + f"min_tokens must be less than or equal to " + f"max_tokens={self.max_tokens}, got {self.min_tokens}.") + if self.logprobs is not None and self.logprobs < 0: + raise ValueError( + f"logprobs must be non-negative, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: + raise ValueError(f"prompt_logprobs must be non-negative, got " + f"{self.prompt_logprobs}.") + 
if (self.truncate_prompt_tokens is not None + and self.truncate_prompt_tokens < 1): + raise ValueError(f"truncate_prompt_tokens must be >= 1, " + f"got {self.truncate_prompt_tokens}") + assert isinstance(self.stop, list) + if any(not stop_str for stop_str in self.stop): + raise ValueError("stop cannot contain an empty string.") + if self.stop and not self.detokenize: + raise ValueError( + "stop strings are only supported when detokenize is True. " + "Set detokenize=True to use stop.") + if self.best_of != self._real_n and self.output_kind == ( + RequestOutputKind.DELTA): + raise ValueError("best_of must equal n to use output_kind=DELTA") + + def _verify_greedy_sampling(self) -> None: + if self.n > 1: + raise ValueError("n must be 1 when using greedy sampling, " + f"got {self.n}.") + + def update_from_generation_config( + self, + generation_config: dict[str, Any], + model_eos_token_id: Optional[int] = None) -> None: + """Update if there are non-default values from generation_config""" + + if model_eos_token_id is not None: + # Add the eos token id into the sampling_params to support + # min_tokens processing. + self._all_stop_token_ids.add(model_eos_token_id) + + # Update eos_token_id for generation + if (eos_ids := generation_config.get("eos_token_id")) is not None: + # it can be either int or list of int + eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids) + if model_eos_token_id is not None: + # We don't need to include the primary eos_token_id in + # stop_token_ids since it's handled separately for stopping + # purposes. + eos_ids.discard(model_eos_token_id) + if eos_ids: + self._all_stop_token_ids.update(eos_ids) + if not self.ignore_eos: + eos_ids.update(self.stop_token_ids) + self.stop_token_ids = list(eos_ids) + + def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: + if not self.bad_words: + return + self._bad_words_token_ids = [] + for bad_word in self.bad_words: + # To prohibit words both at the beginning + # and in the middle of text + # (related to add_prefix_space tokenizer parameter) + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + prompt = prefix + bad_word.lstrip() + + if isinstance(tokenizer, MistralTokenizer): + # Mistral tokenizers should not add special tokens + prompt_token_ids = tokenizer.encode(text=prompt) + else: + prompt_token_ids = tokenizer.encode( + text=prompt, add_special_tokens=False) + + # If no space at the beginning + # or if prefix space produces a new word token + if (not add_prefix_space) or ( + add_prefix_space and prompt_token_ids[0] + != self._bad_words_token_ids[-1][0] + and len(prompt_token_ids) == len( + self._bad_words_token_ids[-1])): + self._bad_words_token_ids.append(prompt_token_ids) + + invalid_token_ids = [ + token_id for bad_words_token_ids in self._bad_words_token_ids + for token_id in bad_words_token_ids + if token_id < 0 or token_id > tokenizer.max_token_id + ] + if len(invalid_token_ids) > 0: + raise ValueError( + f"The model vocabulary size is {tokenizer.max_token_id+1}," + f" but the following tokens" + f" were specified as bad: {invalid_token_ids}." 
+ f" All token id values should be integers satisfying:" + f" 0 <= token_id <= {tokenizer.max_token_id}.") + + @cached_property + def sampling_type(self) -> SamplingType: + if self.temperature < _SAMPLING_EPS: + return SamplingType.GREEDY + if self.seed is not None: + return SamplingType.RANDOM_SEED + return SamplingType.RANDOM + + @property + def all_stop_token_ids(self) -> set[int]: + return self._all_stop_token_ids + + @property + def bad_words_token_ids(self) -> Optional[list[list[int]]]: + # For internal use only. Backward compatibility not guaranteed + return self._bad_words_token_ids + + def clone(self) -> "SamplingParams": + """Deep copy, but maybe not the LogitsProcessor objects. + + LogitsProcessor objects may contain an arbitrary, nontrivial amount of + data that is expensive to copy. However, if not copied, the processor + needs to support parallel decoding for multiple sequences + See https://github.com/vllm-project/vllm/issues/3087 + """ + + logit_processor_refs = None if self.logits_processors is None else { + id(lp): lp.clone() if hasattr(lp, 'clone') else lp + for lp in self.logits_processors + } + return copy.deepcopy(self, memo=logit_processor_refs) + + def __repr__(self) -> str: + return ( + f"SamplingParams(n={self.n}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " + f"repetition_penalty={self.repetition_penalty}, " + f"temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"min_p={self.min_p}, " + f"seed={self.seed}, " + f"stop={self.stop}, " + f"stop_token_ids={self.stop_token_ids}, " + f"bad_words={self.bad_words}, " + f"include_stop_str_in_output={self.include_stop_str_in_output}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens}, " + f"min_tokens={self.min_tokens}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"skip_special_tokens={self.skip_special_tokens}, " + "spaces_between_special_tokens=" + f"{self.spaces_between_special_tokens}, " + f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " + f"guided_decoding={self.guided_decoding}, " + f"extra_args={self.extra_args})") + + +class BeamSearchParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): # type: ignore[call-arg] + """Beam search parameters for text generation.""" + beam_width: int + max_tokens: int + ignore_eos: bool = False + temperature: float = 0.0 + length_penalty: float = 1.0 + include_stop_str_in_output: bool = False diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py new file mode 100644 index 0000000..1d7675d --- /dev/null +++ b/vllm/scalar_type.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
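+#
+# Illustrative example based on the definitions at the bottom of this file:
+# the GPTQ-style uint4b8 type is an unsigned 4-bit integer stored with a bias
+# of 8, so raw stored values 0..15 represent the logical range -8..7
+# (uint4b8.min() == -8, uint4b8.max() == 7).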
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
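+        # (Illustrative check: for the float8_e4m3fn type defined at the
+        # bottom of this file, this packs a mantissa of 6 with a rebased
+        # exponent of 1031, and the resulting double decodes to 448.0, the
+        # expected e4m3fn maximum.)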
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/vllm/scripts.py b/vllm/scripts.py new file mode 100644 index 0000000..7e569d2 --- /dev/null +++ b/vllm/scripts.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.entrypoints.cli.main import main as vllm_main +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +# Backwards compatibility for the move from vllm.scripts to +# vllm.entrypoints.cli.main +def main(): + logger.warning("vllm.scripts.main() is deprecated. Please re-install " + "vllm or use vllm.entrypoints.cli.main.main() instead.") + vllm_main() diff --git a/vllm/sequence.py b/vllm/sequence.py new file mode 100644 index 0000000..66a0740 --- /dev/null +++ b/vllm/sequence.py @@ -0,0 +1,1516 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Sequence and its related classes.""" +import copy +import enum +from abc import ABC, abstractmethod +from array import array +from collections import defaultdict +from collections.abc import Mapping +from collections.abc import Sequence as GenericSequence +from dataclasses import dataclass, field +from functools import reduce +from typing import Any, Callable, Optional, Union + +import msgspec +import torch + +from vllm.inputs import SingletonInputs, SingletonInputsAdapter +from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams + +VLLM_TOKEN_ID_ARRAY_TYPE = "l" + +VLLM_INVALID_TOKEN_ID = -1 + + +def array_full(token_id: int, count: int): + """:class:`array` equivalent of :func:`numpy.full`.""" + return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + + +# We use dataclass for now because it is used for +# openai server output, and msgspec is not serializable. +# TODO(sang): Fix it. 
+@dataclass +class Logprob: + """Infos for supporting OpenAI compatible logprobs and token ranks. + + Attributes: + logprob: The logprob of chosen token + rank: The vocab rank of chosen token (>=1) + decoded_token: The decoded chosen token index + """ + logprob: float + rank: Optional[int] = None + decoded_token: Optional[str] = None + + +# {token_id -> logprob} per each sequence group. None if the corresponding +# sequence group doesn't require prompt logprob. +PromptLogprobs = list[Optional[dict[int, Logprob]]] +# {token_id -> logprob} for each sequence group. +SampleLogprobs = list[dict[int, Logprob]] + + +class SequenceStatus(enum.IntEnum): + """Status of a sequence.""" + WAITING = 0 + RUNNING = 1 + SWAPPED = 2 + # Note: anything after SWAPPED (2) will be considered + # as a finished status. + FINISHED_STOPPED = 3 + FINISHED_LENGTH_CAPPED = 4 + FINISHED_ABORTED = 5 + FINISHED_IGNORED = 6 + + @staticmethod + def is_finished(status: "SequenceStatus") -> bool: + return status > SequenceStatus.SWAPPED + + @staticmethod + def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + if status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + elif status == SequenceStatus.FINISHED_ABORTED: + finish_reason = "abort" + elif status == SequenceStatus.FINISHED_IGNORED: + # The ignored sequences are the sequences whose prompt lengths + # are longer than the model's length cap. Therefore, the stop + # reason should also be "length" as in OpenAI API. + finish_reason = "length" + else: + finish_reason = None + return finish_reason + + +class SequenceStage(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + + +@dataclass +class RequestMetrics: + """Metrics associated with a request. + + Attributes: + arrival_time: The time when the request arrived. + first_scheduled_time: The time when the request was first scheduled. + first_token_time: The time when the first token was generated. + time_in_queue: The time the request spent in the queue. + finished_time: The time when the request was finished. + scheduler_time: The time spent in the scheduler when this request was + being considered by the scheduler. + model_forward_time: The time spent in the model forward pass when this + request was in the batch. + model_execute_time: The time spent in the model execute function. This + will include model forward, block/sync across + workers, cpu-gpu sync time and sampling time. + spec_token_acceptance_counts: number of accepted speculative tokens at + each position; the first token is from + the target model and is always accepted; + e.g., when it's [10, 8, 4, 2] for a req, + it means there were 10 forward passes in + total, and there were 8, 4, 2 accepted + tokens at 1st, 2nd, 3rd speculation step. + """ + arrival_time: float + last_token_time: float + first_scheduled_time: Optional[float] + first_token_time: Optional[float] + time_in_queue: Optional[float] + finished_time: Optional[float] = None + scheduler_time: Optional[float] = None + model_forward_time: Optional[float] = None + model_execute_time: Optional[float] = None + spec_token_acceptance_counts: Optional[list[int]] = None + + +class SequenceDataDelta( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta SequenceData to send to workers per step.""" + # A new token to be appended to existing SequenceData. 
+ new_output_token_ids: list[int] + # Overwriting existing `cumulative_logprob` + new_cumulative_logprob: float + # Overwriting existing `num_computed_tokens`. + new_num_computed_tokens: int + # Overwriting existing `stage`. + new_stage: SequenceStage + + +class SequenceData(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] + """Data associated with a sequence. + + Args: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. Set to an empty list if + None. + + Attributes: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. + cumulative_logprob: The cumulative log probability of the output. + """ + # NOTE: we cannot use Union[list, array] because msgspec cannot support + # union of 2 list types. + _prompt_token_ids: array + _output_token_ids: array = msgspec.field( + default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + + ### The below fields should not be passed as an argument ### + _cumulative_logprob: float = 0.0 + _prompt_token_ids_tuple: tuple[int, + ...] = msgspec.field(default_factory=tuple) + # The number of tokens that are computed (that run against the model). + _num_computed_tokens: int = 0 + # The number of tokens with prefix cache hit. + _num_cached_tokens: int = 0 + _stage: SequenceStage = SequenceStage.PREFILL + _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) + + # It is used to get delta input. It is reset when `get_delta_and_reset` + # is called. + _new_appended_tokens: list[int] = msgspec.field(default_factory=list) + + # It is used to compute mrope_position_ids. + _mrope_position_delta: Optional[int] = None + + @staticmethod + def from_prompt_token_counts( + *token_counts: tuple[int, int]) -> "SequenceData": + """ + Construct a :class:`SequenceData` instance by concatenating + prompt token sequences. + + Each tuple represents one token sequence, expressed in the form + :code:`(token_id, count)`. + """ + if len(token_counts) == 0: + return SequenceData.from_seqs([]) + + prompt_token_ids_arr = reduce( + array.__iadd__, + (array_full(token_id, count) for token_id, count in token_counts), + ) + + return SequenceData(prompt_token_ids_arr) + + @staticmethod + def from_seqs( + prompt_token_ids: GenericSequence[int], + output_token_ids: Optional[GenericSequence[int]] = None, + ) -> "SequenceData": + """ + Construct a :class:`SequenceData` instance from prompt and output + token sequences. + """ + prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + prompt_token_ids) + + if output_token_ids is None: + return SequenceData(prompt_token_ids_arr) + + output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + output_token_ids) + + return SequenceData(prompt_token_ids_arr, + _output_token_ids=output_token_ids_arr) + + def __post_init__(self) -> None: + assert self._prompt_token_ids.typecode == "l" + assert self._output_token_ids.typecode == "l" + self._prompt_token_ids_tuple: tuple[int, ...] 
= tuple( + self._prompt_token_ids) + self._update_cached_all_tokens() + + def _update_cached_all_tokens(self): + assert isinstance(self._prompt_token_ids, array) + assert isinstance(self._output_token_ids, array) + self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + + self._output_token_ids) + + @property + def cumulative_logprob(self) -> float: + return self._cumulative_logprob + + @property + def prompt_token_ids(self) -> tuple[int, ...]: + return self._prompt_token_ids_tuple + + @prompt_token_ids.setter + def prompt_token_ids(self, new_prompt_token_ids) -> None: + raise NotImplementedError + + @property + def prompt_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ + return self._prompt_token_ids + + @property + def output_token_ids(self) -> tuple[int, ...]: + return tuple(self._output_token_ids) + + @output_token_ids.setter + def output_token_ids(self, + new_output_token_ids: GenericSequence[int]) -> None: + self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids) + self._update_cached_all_tokens() + + @property + def output_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ + assert isinstance(self._output_token_ids, array) + return self._output_token_ids + + @property + def mrope_position_delta(self) -> Optional[int]: + return self._mrope_position_delta + + @mrope_position_delta.setter + def mrope_position_delta(self, new_mrope_position_delta): + self._mrope_position_delta = new_mrope_position_delta + + def append_token_id(self, token_id: int, logprob: float) -> None: + self._output_token_ids.append(token_id) + self._new_appended_tokens.append(token_id) + self._cached_all_token_ids.append(token_id) + self._cumulative_logprob += logprob + + def get_len(self) -> int: + return len(self._output_token_ids) + len(self._prompt_token_ids) + + def get_prompt_len(self) -> int: + return len(self._prompt_token_ids) + + def get_output_len(self) -> int: + return len(self._output_token_ids) + + def get_token_ids(self) -> list[int]: + return self._cached_all_token_ids + + def get_prefix_token_ids( + self, num_tokens: int + ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: + """Get prefix tokens, and make the return value hashable""" + prompt_length = self.get_prompt_len() + if num_tokens > prompt_length: + return (self._prompt_token_ids_tuple, + tuple(self._output_token_ids[:num_tokens - prompt_length])) + else: + return (self._prompt_token_ids_tuple[:num_tokens], None) + + def get_num_computed_tokens(self) -> int: + """Return the number of prefill tokens that are already computed.""" + return self._num_computed_tokens + + def update_num_computed_tokens(self, num_new_computed_tokens: int): + """Update number of tokens computed so far.""" + self._num_computed_tokens += num_new_computed_tokens + assert self._num_computed_tokens <= self.get_len(), ( + self._num_computed_tokens, self.get_len()) + # If all tokens are computed, it means it is in decoding phase. 
+ if self.get_num_uncomputed_tokens() == 0: + self._stage = SequenceStage.DECODE + + def get_num_cached_tokens(self) -> int: + """Return the number of tokens with prefix cache hit.""" + return self._num_cached_tokens + + def update_num_cached_tokens(self, num_cached_tokens: int): + """Update the number of tokens with prefix cache hit.""" + self._num_cached_tokens = num_cached_tokens + + def reset_state_for_recompute(self) -> None: + """Reset the number of computed tokens from this sequence. It is + supposed to be called when a sequence needs to be started from + the beginning again (e.g., sequence is preempted). + """ + self._num_computed_tokens = 0 + self._stage = SequenceStage.PREFILL + self._new_appended_tokens = [] + + def get_num_uncomputed_tokens(self) -> int: + """Return the number of prefill tokens that are not computed.""" + # we use `get_len()` which includes prompt_len + output_len instead + # of prompt_len here. This is because during recompute we need to + # prefill for both prompt and output. + return self.get_len() - self.get_num_computed_tokens() + + def get_last_token_id(self) -> int: + if not self._output_token_ids: + return self._prompt_token_ids[-1] + return self._output_token_ids[-1] + + def get_prompt_token_ids(self) -> tuple[int, ...]: + return self.prompt_token_ids + + def get_output_token_ids(self) -> tuple[int, ...]: + return self.output_token_ids + + def get_delta_and_reset(self) -> SequenceDataDelta: + delta = SequenceDataDelta(self._new_appended_tokens, + self._cumulative_logprob, + self.get_num_computed_tokens(), self.stage) + # Reset delta state. + self._new_appended_tokens = [] + return delta + + def apply_delta(self, delta: SequenceDataDelta): + self._num_computed_tokens = delta.new_num_computed_tokens + self._cumulative_logprob = delta.new_cumulative_logprob + self._stage = delta.new_stage + self._output_token_ids.extend(delta.new_output_token_ids) + self._cached_all_token_ids.extend(delta.new_output_token_ids) + + @property + def stage(self) -> SequenceStage: + return self._stage + + def __repr__(self) -> str: + return (f"SequenceData(" + f"prompt_token_ids={self._prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()})") + + +class Sequence: + """Stores the data, status, and block information of a sequence. + + The sequence is constructed from the :data:`DecoderOnlyInputs` + (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder) + instance passed in through the :code:`inputs` constructor argument. + + Args: + seq_id: The ID of the sequence. + inputs: The inputs of the sequence. + block_size: The block size of the sequence. Should be the same as the + block size used by the block manager and cache engine. + eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. + lora_request: LoRA request. + prompt_adapter_request: Prompt Adapter request. 
+ """ + + def __init__( + self, + seq_id: int, + inputs: SingletonInputs, + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + self.seq_id = seq_id + self.inputs = SingletonInputsAdapter(inputs) + self.block_size = block_size + self.eos_token_id = eos_token_id + self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request + + self.data = SequenceData.from_seqs(self.prompt_token_ids) + self.output_logprobs: SampleLogprobs = [] + self.output_text = "" + + self.status = SequenceStatus.WAITING + self.stop_reason: Union[int, str, None] = None + + # These are used to keep track of delta outputs + self._last_output_token_ids_offset: int = 0 + self._last_output_text_offset: int = 0 + + # Used for incremental detokenization + self.prefix_offset = 0 + self.read_offset = 0 + # Input + output tokens + self.tokens: Optional[list[str]] = None + + @property + def n_blocks(self) -> int: + return (self.get_len() + self.block_size - 1) // self.block_size + + @property + def prompt(self) -> Optional[str]: + return self.inputs.prompt + + @property + def prompt_token_ids(self) -> list[int]: + return self.inputs.prompt_token_ids + + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self.inputs.prompt_embeds + + @property + def token_type_ids(self) -> list[int]: + return self.inputs.token_type_ids + + @property + def multi_modal_data(self) -> "MultiModalDataDict": + return self.inputs.multi_modal_data + + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + return self.inputs.multi_modal_placeholders + + @property + def mm_processor_kwargs(self) -> dict[str, Any]: + return self.inputs.mm_processor_kwargs + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + def get_output_text_to_return(self, buffer_length: int, + delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + + # We return the full output text if the sequence is finished. 
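+        # Otherwise `buffer_length` characters are held back so that a partial
+        # stop string is never streamed out. Illustrative example: with
+        # stop=["###"] the buffer length is 2, so the last two characters of
+        # an unfinished sequence are withheld until more text arrives or the
+        # sequence finishes.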
+ truncate = buffer_length and not self.is_finished() + if not delta: + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + length = len(self.output_text) + if truncate: + length -= buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + + def get_output_token_ids_to_return( + self, delta: bool) -> Union[GenericSequence[int], int]: + """If delta is True, only new tokens since the last call to + this method are returned""" + if not delta: + return self.get_output_token_ids() + + output_len = self.get_output_len() + + # Get the number of new tokens + num_new_tokens = output_len - self._last_output_token_ids_offset + self._last_output_token_ids_offset = output_len + + # Return new tokens + if num_new_tokens == 1: + # Optimization for single decode token case + # (which is what we have most of the time) + return self.data._cached_all_token_ids[-1] + + if num_new_tokens == 0: + return [] + + return self.data._cached_all_token_ids[-num_new_tokens:] + + def hash_of_block(self, logical_idx: int) -> int: + # TODO This can produce incorrect hash when block size > prompt size + + # Compute the number of tokens in the sequence + # TODO: The current hashing function is O(L^2). We should optimize + # this in the future. + num_tokens = self.num_hashed_tokens_of_block(logical_idx) + hashed_tokens = self.data.get_prefix_token_ids(num_tokens) + return hash((hashed_tokens, self.lora_int_id)) + + def extra_hash(self) -> Optional[int]: + """ + This function computes an extra hash for a sequence, specifically + designed for prefix caching mode. The final sequence hash is determined + by applying token_ids from the sequence's blocks. + """ + if self.prompt_adapter_id == 0 and self.lora_int_id == 0: + return None + + # NOTE: If there are additional factors influencing the block aside from + # token_ids, include them as input parameters to the hash. + return hash((self.prompt_adapter_id, self.lora_int_id)) + + def num_hashed_tokens_of_block(self, logical_idx: int): + return logical_idx * self.block_size + self.block_size + + def reset_state_for_recompute(self): + """Reset the sequence states for recomputation.""" + self.data.reset_state_for_recompute() + + def append_token_id(self, token_id: int, logprobs: dict[int, + Logprob]) -> None: + assert token_id in logprobs + self.output_logprobs.append(logprobs) + self.data.append_token_id(token_id, logprobs[token_id].logprob) + + def get_len(self) -> int: + return self.data.get_len() + + def get_prompt_len(self) -> int: + return self.data.get_prompt_len() + + def get_output_len(self) -> int: + return self.data.get_output_len() + + def get_token_ids(self) -> list[int]: + return self.data.get_token_ids() + + def get_prompt_token_ids(self) -> tuple[int, ...]: + return self.data.get_prompt_token_ids() + + def get_last_token_id(self) -> int: + return self.data.get_last_token_id() + + def get_output_token_ids(self) -> tuple[int, ...]: + return self.data.get_output_token_ids() + + def get_cumulative_logprob(self) -> float: + return self.data.cumulative_logprob + + def is_finished(self) -> bool: + return SequenceStatus.is_finished(self.status) + + def fork(self, new_seq_id: int) -> "Sequence": + new_seq = copy.deepcopy(self) + new_seq.seq_id = new_seq_id + return new_seq + + def get_num_new_tokens(self) -> int: + """Get the number of new tokens to be computed. + + Returns: + The new number of tokens to be computed. 
I.e., 1 for decode, or + the remaining prompt size for prefill. + """ + if self.data.stage == SequenceStage.DECODE: + return 1 + return self.data.get_num_uncomputed_tokens() + + def get_num_computed_tokens(self) -> int: + return self.data.get_num_computed_tokens() + + def is_prefill(self) -> bool: + return self.data.stage == SequenceStage.PREFILL + + def __repr__(self) -> str: + return (f"Sequence(seq_id={self.seq_id}, " + f"status={self.status.name}, " + f"num_blocks={self.n_blocks})") + + +class SequenceGroupState(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] + """Mutable state tied to a specific sequence group""" + + # for multi-step decoding + num_steps: int = 1 + current_step: int = 0 + + @property + def remaining_steps(self) -> int: + return self.num_steps - self.current_step + + +class SequenceGroup: + """A group of sequences that are generated from the same prompt. + + Args: + request_id: The ID of the request. + seqs: The list of sequences. + sampling_params: The sampling parameters used to generate the outputs. + arrival_time: The arrival time of the request. + lora_request: LoRA request. + pooling_params: The parameters used to generate the pooler + for a pooling model. + pooled_data: The extracted hidden states from a pooling model. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request. + priority: User-defined priority of the request. + draft_size: The number of speculative tokens plus one from the target + model; equal to max number of tokens a step can generate + for single-draft speculative decoding but larger than + that for multi-draft SD (currently not supported). + """ + + def __init__(self, + request_id: str, + seqs: list[Sequence], + arrival_time: float, + sampling_params: Optional[SamplingParams] = None, + lora_request: Optional[LoRARequest] = None, + pooling_params: Optional[PoolingParams] = None, + pooled_data: Optional[torch.Tensor] = None, + encoder_seq: Optional[Sequence] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + draft_size: int = 1) -> None: + self.request_id = request_id + self.seqs = seqs + self.first_seq = seqs[0] + self.arrival_time = arrival_time + self.is_single_seq = len(seqs) == 1 + self.seqs_dict = {seq.seq_id: seq for seq in seqs} + + self.sampling_params = sampling_params + self.metrics = RequestMetrics(arrival_time=arrival_time, + last_token_time=arrival_time, + first_scheduled_time=None, + first_token_time=None, + time_in_queue=None, + spec_token_acceptance_counts=[0] * + draft_size) + self.last_token_latency = 0.0 + self.lora_request = lora_request + self.prompt_logprobs: Optional[PromptLogprobs] = None + self.state = SequenceGroupState() + self.pooling_params = pooling_params + self.pooled_data = pooled_data + self.prompt_adapter_request = prompt_adapter_request + self.encoder_seq = encoder_seq + self.trace_headers = trace_headers + self.priority = priority + + self.cached_request_output = None + + @property + def prompt(self) -> Optional[str]: + return self.first_seq.prompt + + @property + def prompt_token_ids(self) -> list[int]: + return self.first_seq.prompt_token_ids + + @property + def encoder_prompt(self) -> Optional[str]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt is distinct + # from the decoder's. 
+ return (self.encoder_seq.prompt + if self.encoder_seq is not None else None) + + @property + def encoder_prompt_token_ids(self) -> Optional[list[int]]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt token ids are + # distinct from the decoder's. + return (self.encoder_seq.prompt_token_ids + if self.encoder_seq is not None else None) + + @property + def token_type_ids(self) -> Optional[list[int]]: + return self.first_seq.token_type_ids + + @property + def multi_modal_data(self) -> MultiModalDataDict: + if self.first_seq.multi_modal_data: + return self.first_seq.multi_modal_data + elif self.encoder_seq is not None: + return self.encoder_seq.multi_modal_data + return {} + + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + if self.first_seq.multi_modal_data: + return self.first_seq.multi_modal_placeholders + elif self.encoder_seq is not None: + return self.encoder_seq.multi_modal_placeholders + return {} + + @property + def mm_processor_kwargs(self) -> dict[str, Any]: + if self.first_seq.multi_modal_data: + return self.first_seq.mm_processor_kwargs + elif self.encoder_seq is not None: + return self.encoder_seq.mm_processor_kwargs + return {} + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + if self.prompt_adapter_request else 0 + + def init_multi_step(self, num_steps: int) -> None: + self.state.num_steps = num_steps + self.state.current_step = 0 + + def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int, + num_scheduler_steps: int, + is_multi_step: bool, + enable_chunking: bool) -> None: + + if not is_multi_step: + self.init_multi_step(num_steps=num_scheduler_steps) + return + + # Multi-Step case + is_prefill = self.is_prefill() + + # The asserts below reflect the expectations of the current system. + if is_prefill and enable_chunking: + assert num_lookahead_slots == num_scheduler_steps + self.init_multi_step(num_steps=num_lookahead_slots) + else: + is_decode: bool = not is_prefill + # If it is a prefill, num_lookahead_slots must be 0 + assert num_lookahead_slots == 0 or is_decode + # If it is a decode, num_lookahead_slots + 1 must match + # the scheduler steps. + assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill + self.init_multi_step(num_steps=num_lookahead_slots + 1) + + def set_last_token_time(self, now: float) -> None: + """Sets the last token time for Request level timings.""" + # If still in prefill phase, assertion fails. 
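+        # (TTFT is recorded separately in maybe_set_first_token_time below;
+        # this method only measures decode-phase inter-token latency.)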
+ assert not self.is_prefill(), ( + "seq_group.set_last_token_time() should not be called " + "if the seq_group is in prefill phase.") + self.last_token_latency = now - self.metrics.last_token_time + self.metrics.last_token_time = now + + def get_last_token_latency(self) -> float: + """Returns the latency of the last token.""" + assert not self.is_prefill(), ( + "seq_group.get_last_token_latency() should not be called " + "if the seq_group is in prefill phase.") + return self.last_token_latency + + def maybe_set_first_token_time(self, time: float) -> None: + """Sets the first token time for Request level timings.""" + # Note: in a case where a sequence_group is swapped and + # recomputed, the time between iterations is counted + # in TPOT, rather than recalculating TTFT (since from the ) + # POV of the user, there is simply a long generation delay. + if (self.metrics.first_token_time is None + and self.first_seq.get_output_len() == 1): + self.metrics.first_token_time = time + + def maybe_set_first_scheduled_time(self, time: float) -> None: + """Sets the first scheduled time and time in queue for Request + level timings.""" + if self.metrics.first_scheduled_time is None: + self.metrics.first_scheduled_time = time + self.metrics.time_in_queue = time - self.metrics.arrival_time + + def set_finished_time(self, time: Optional[float]) -> None: + """Sets the finished time for Request level timings.""" + self.metrics.finished_time = time + + def get_max_num_running_seqs(self) -> int: + """The maximum number of sequences running in parallel in the remaining + lifetime of the request.""" + if self.is_single_seq: + return 0 if self.first_seq.is_finished() else 1 + return self.num_seqs() - self.num_finished_seqs() + + def get_seqs( + self, + status: Optional[SequenceStatus] = None, + ) -> list[Sequence]: + if status is None: + return self.seqs + + if self.is_single_seq: + return self.seqs if self.first_seq.status == status else [] + + return [seq for seq in self.seqs if seq.status == status] + + def is_encoder_decoder(self) -> bool: + return self.encoder_seq is not None + + def get_encoder_seq(self) -> Optional[Sequence]: + return self.encoder_seq + + def get_finished_seqs(self) -> list[Sequence]: + if self.is_single_seq: + return self.seqs if self.first_seq.is_finished() else [] + + return [seq for seq in self.seqs if seq.is_finished()] + + def update_num_computed_tokens(self, num_new_computed_tokens: int): + """Update number of tokens computed so far.""" + for seq in self.seqs: + if not seq.is_finished(): + seq.data.update_num_computed_tokens(num_new_computed_tokens) + + def get_num_uncomputed_tokens(self) -> int: + num_uncomputed_tokens = 0 + for seq in self.seqs: + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + return num_uncomputed_tokens + + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + # Optimization. We don't need to call get_seqs if we don't need to + # filter by states. 
+ if status is None: + return len(self.seqs) + + if self.is_single_seq: + return 1 if self.seqs[0].status == status else 0 + + return len(self.get_seqs(status)) + + def num_finished_seqs(self) -> int: + if self.is_single_seq: + return 1 if self.seqs[0].is_finished() else 0 + return len(self.get_finished_seqs()) + + def is_finished(self) -> bool: + if self.is_single_seq: + return self.first_seq.is_finished() + return all(seq.is_finished() for seq in self.seqs) + + def is_prefill(self) -> bool: + return self.first_seq.is_prefill() + + def __repr__(self) -> str: + return (f"SequenceGroup(request_id={self.request_id}, " + f"sampling_params={self.sampling_params}, " + f"num_seqs={len(self.seqs)})") + + +class SequenceGroupMetadataDelta( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta of SequenceGroupMetadata. + + After sending the first SequenceGroupMetadata, vLLM scheduler + only sends delta to reduce the data payload size. + """ + seq_data_delta: dict[int, SequenceDataDelta] + request_id: str + block_tables: dict[int, list[int]] + is_prompt: bool + do_sample: bool = True + token_chunk_size: Optional[int] = None + computed_block_nums: Optional[list[int]] = None + state: Optional[SequenceGroupState] = msgspec.field( + default_factory=lambda: SequenceGroupState()) + + +class SequenceGroupMetadata( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Metadata for a sequence group. Used to create `AttentionMetadata`. + + Args: + request_id: The ID of the request. + is_prompt: Whether the request is at prompt stage. + seq_data: The sequence data. (Seq id -> sequence data) + sampling_params: The sampling parameters used to generate the outputs. + block_tables: The block tables. (Seq id -> list of physical block + numbers) + do_sample: True if sampling is required. Sampling is not required when + e.g., prefill is chunked, and the current iteration only computes + query tokens for prefill, we don't need sampling. + token_chunk_size: The number of tokens to be processed (per sequence). + None if chunking is not required. + lora_request: LoRA request. + computed_block_nums: The block numbers that are already computed, + used in prefix caching. + state: Internal state tied to this sequence group. + multi_modal_data: Multi modal data. + mm_processor_kwargs: Multimodal input processor / mapper overrides. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + prompt_adapter_request: Prompt Adapter request. + """ + + request_id: str + is_prompt: bool + seq_data: dict[int, SequenceData] + sampling_params: Optional[SamplingParams] + block_tables: dict[int, list[int]] + do_sample: bool = True + pooling_params: Optional[PoolingParams] = None + lora_request: Optional[LoRARequest] = None + computed_block_nums: Optional[list[int]] = None + state: Optional[SequenceGroupState] = msgspec.field( + default_factory=lambda: SequenceGroupState()) + # "MultiModalDataDict" types. We have to use Any due to msgspec + # doesn't allow to have union of 2 different dicts. 
+ token_type_ids: Optional[list[int]] = None + multi_modal_data: Optional[Any] = None + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None + mm_processor_kwargs: Optional[dict[str, Any]] = None + encoder_seq_data: Optional[SequenceData] = None + cross_block_table: Optional[list[int]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + token_chunk_size: Optional[int] = None + + ### Stateful fields that are lazily defined. ### + # The number of speculative tokens adopted in this request. + # None means specuative decoding is not used. + # Zero means speculative decoding is disabled for some reasons. + # TODO: We should maintain this states out of the sequence group. + num_speculative_tokens: Optional[int] = None + + def __post_init__(self): + if self.seq_data is not None and self.token_chunk_size is None: + if self.is_prompt: + self.token_chunk_size = next(iter( + self.seq_data.values())).get_len() + else: + self.token_chunk_size = 1 + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ + if self.prompt_adapter_request else 0 + + # Multi-Step Chunked-Prefill property + @property + def is_single_step_prompt(self) -> bool: + # do_sample is true, only when the token_chunk_size matches the + # num_uncomputed_tokens of the sequence. This indicates that + # the prompt will finish processing in a single `execute_model` + # step. + return self.is_prompt and self.do_sample + + def get_first_seq_id(self) -> int: + # This is an efficient way of fetching the seq_id when + # we know this SequenceGroup has only one sequence. + return next(iter(self.seq_data)) + + def apply_delta(self, + sequence_group_metadata_delta: SequenceGroupMetadataDelta): + for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): + self.seq_data[id].apply_delta(delta) + assert self.request_id == sequence_group_metadata_delta.request_id + self.block_tables = sequence_group_metadata_delta.block_tables + self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size + self.do_sample = sequence_group_metadata_delta.do_sample + self.is_prompt = sequence_group_metadata_delta.is_prompt + + def finish_step(self) -> None: + assert self.state is not None + assert self.state.current_step < self.state.num_steps, \ + f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa + self.state.current_step += 1 + + +class SequenceOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """The model output associated with a sequence. + + Args: + parent_seq_id: The ID of the parent sequence (for forking in beam + search). + output_token: The output token ID. + logprobs: The logprobs of the output token. 
+            (Token id -> logP(x_i+1 | x_0, ..., x_i))
+    """
+    parent_seq_id: int
+    output_token: int
+    logprobs: dict[int, Logprob]
+
+    def __repr__(self) -> str:
+        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
+                f"output_token={self.output_token}, "
+                f"logprobs={self.logprobs})")
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SequenceOutput):
+            raise NotImplementedError()
+        equal = (self.parent_seq_id == other.parent_seq_id
+                 and self.output_token == other.output_token)
+        log_probs_equal = other.logprobs == self.logprobs
+        return equal and log_probs_equal
+
+
+class SequenceGroupOutput(ABC):
+    """The base class for model outputs associated with a sequence group."""
+
+    @abstractmethod
+    def __repr__(self) -> str:
+        pass
+
+    @abstractmethod
+    def __eq__(self, other: object) -> bool:
+        pass
+
+
+class CompletionSequenceGroupOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+    """The model output associated with a completion sequence group."""
+    __metaclass__ = SequenceGroupOutput
+    samples: list[SequenceOutput]
+    # Prompt logprob for each prompt query token.
+    prompt_logprobs: Optional[PromptLogprobs]
+    step_index: Optional[int] = 0
+
+    def __repr__(self) -> str:
+        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
+                f"prompt_logprobs={self.prompt_logprobs})")
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, CompletionSequenceGroupOutput):
+            raise NotImplementedError()
+        return (self.samples == other.samples
+                and self.prompt_logprobs == other.prompt_logprobs)
+
+
+class PoolingSequenceGroupOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True,  # type: ignore[call-arg]
+):
+    """The model output associated with a pooling sequence group."""
+    __metaclass__ = SequenceGroupOutput
+    # Annotated as Any to be compatible with msgspec
+    # The actual type is in SequenceGroup.pooled_data
+    data: Any
+
+    def __repr__(self) -> str:
+        return f"PoolingSequenceGroupOutput(data={self.data})"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PoolingSequenceGroupOutput):
+            raise NotImplementedError()
+        return self.data == other.data
+
+
+# cannot use msgspec.Struct here because Dynamo does not support it
+@dataclass
+class IntermediateTensors:
+    """For all pipeline stages except the last, we need to return the hidden
+    states and residuals to be sent to the next stage. This data structure
+    contains the hidden states and residuals for a request.
+    """
+
+    tensors: dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
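+        # Illustrative content only (the actual keys are defined per model):
+        # for a pipeline-parallel decoder this typically looks like
+        # {"hidden_states": <Tensor>, "residual": <Tensor>} for the tokens of
+        # the current batch.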
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def items(self):
+        return self.tensors.items()
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"IntermediateTensors(tensors={self.tensors})"
+
+
+class PoolerOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+    """The output from a pooling operation in the pooling model."""
+    outputs: list[PoolingSequenceGroupOutput]
+
+    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+
+def get_all_seq_ids(
+        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
+
+
+def get_all_seq_ids_and_request_ids(
+    seq_group_metadata_list: list[SequenceGroupMetadata]
+) -> tuple[list[int], dict[str, set[int]]]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids, and a mapping from request id to its sequence ids.
+    """
+    seq_ids: list[int] = []
+    request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
+    for sg in seq_group_metadata_list:
+        for seq_id in sg.seq_data:
+            seq_ids.append(seq_id)
+            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
+    return seq_ids, request_id_seq_ids_mapping
+
+
+class HiddenStates(msgspec.Struct, array_like=True,
+                   omit_defaults=True):  # type: ignore[call-arg]
+    """Hidden states corresponding to in-progress sequences.
+    Used in speculative decoding to pass hidden states from
+    the target model to the proposer model.
+
+    seq_ids are the sequence ids of each entry of the batch
+    dimension of the hidden_states tensor"""
+    # Scorer hidden states. For prefill step, it is used for hidden states of
+    # all tokens, whereas for decode step, it is used for last accepted tokens.
+    hidden_states: torch.Tensor
+    # The sequence group metadata list. Only needed for decode step.
+    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
+    # Scorer hidden states of the 2nd last token proposed by the proposer (
+    # irrespective of whether it was accepted or not). Only used for cases when
+    # last proposed token is accepted (i.e., in case of bonus tokens). For the
+    # case of no bonus tokens, these are ignored.
+ second_last_token_hidden_states: Optional[torch.Tensor] = None + + _seq_ids: list[int] = msgspec.field(default_factory=list) + + def __post_init__(self): + if self.seq_group_metadata_list is not None: + # assert len(self.seq_group_metadata_list) == len(self.hidden_states) + self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) + + @property + def seq_ids(self) -> list[int]: + return self._seq_ids + + def update(self, + hidden_states: torch.Tensor, + seq_group_metadata_list: list[SequenceGroupMetadata], + second_last_token_hidden_states: Optional[torch.Tensor] = None): + """Update hidden states from target model invocation. Only used for + decode steps""" + assert len(seq_group_metadata_list) == len(hidden_states) + self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) + if self.seq_group_metadata_list is not None: + self.seq_group_metadata_list.extend(seq_group_metadata_list) + self.hidden_states = torch.cat([self.hidden_states, hidden_states]) + + if self.second_last_token_hidden_states is not None: + # Adding dummy hidden_states to this to maintain same shape + self.second_last_token_hidden_states = torch.cat([ + self.second_last_token_hidden_states, + torch.zeros_like(hidden_states) + if second_last_token_hidden_states is None else + second_last_token_hidden_states + ]) + + def prune(self, + seq_group_metadata_list: list[SequenceGroupMetadata]) -> None: + """Prune to provided list of sequence ids. Only used for decode steps. + """ + # Currently this prunes all seq_ids not present in + # seq_group_metadata_list which might cause problems where a sequence + # may be "paused" then "resumed" later. This should only prune sequences + # which are confirmed to be aborted. + seq_ids = get_all_seq_ids(seq_group_metadata_list) + if seq_ids != self._seq_ids: + # Batch contents changed - prune removed sequences. + index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] + self.hidden_states = self.hidden_states[index] + if self.seq_group_metadata_list is not None: + self.seq_group_metadata_list = [ + self.seq_group_metadata_list[i] for i in index + ] + if self.second_last_token_hidden_states is not None: + self.second_last_token_hidden_states = self\ + .second_last_token_hidden_states[index] + self._seq_ids = seq_ids + + def expand_with_bonus_tokens( + self, seq_with_bonus_token_in_last_step: set) -> None: + """Expand hidden states for sequences with bonus tokens. 
This is in + alignment with `MultiStepWorker._expand_execute_model_request`.""" + if self.second_last_token_hidden_states is None \ + or not seq_with_bonus_token_in_last_step: + return + + index = [] + expanded_seq_ids = [] + expanded_seq_group_metadata_list = [] + for i, seq_id in enumerate(self._seq_ids): + if seq_id in seq_with_bonus_token_in_last_step: + index.append(i + len(self._seq_ids)) + expanded_seq_ids.append(seq_id) + if self.seq_group_metadata_list is not None: + expanded_seq_group_metadata_list.append( + self.seq_group_metadata_list[i]) + index.append(i) + expanded_seq_ids.append(seq_id) + if self.seq_group_metadata_list is not None: + expanded_seq_group_metadata_list.append( + self.seq_group_metadata_list[i]) + + self._seq_ids = expanded_seq_ids + self.seq_group_metadata_list = expanded_seq_group_metadata_list + self.hidden_states = torch.cat( + [self.hidden_states, self.second_last_token_hidden_states])[index] + + +class ExecuteModelRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """The model execution request, containing CPU metadata only. The LLM + engine should create an instance of this class for each request batch.""" + # The sequence group metadata list. + seq_group_metadata_list: list[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: list[tuple[int, + int]] = msgspec.field(default_factory=list) + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: list[tuple[int, + int]] = msgspec.field(default_factory=list) + # Blocks to copy. Source to dest block. + blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list) + # Virtual engine ID for pipeline parallel. + virtual_engine: int = 0 + # The number of slots for lookahead decoding. + num_lookahead_slots: int = 0 + # The number of requests in the running queue. + running_queue_size: int = 0 + # Optional hidden states from prior step. + previous_hidden_states: Optional[HiddenStates] = None + # The number of forward steps to run. + num_steps: int = 1 + # The step index for spec model input. + spec_step_idx: Optional[int] = None + # Finished request ids since last step. + finished_requests_ids: list[str] = msgspec.field(default_factory=list) + # The last sampled token ids for multi step decoding. 
+ last_sampled_token_ids: Optional[torch.Tensor] = None + # Async callback + async_callback: Optional[Callable] = None + + @property + def is_first_multi_step(self) -> bool: + # TODO(will) make this be able to handle batches with variable number of + # steps + assert len(self.seq_group_metadata_list) > 0 + first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None + return first_seq_group.state.current_step == 0 + + @property + def is_last_step(self) -> bool: + # TODO(will) make this be able to handle batches with variable number of + # steps + assert len(self.seq_group_metadata_list) > 0 + first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None + return first_seq_group.state.remaining_steps == 1 + + @property + def current_step(self) -> int: + # TODO(will) make this be able to handle batches with variable number of + # steps + assert len(self.seq_group_metadata_list) > 0 + state = self.seq_group_metadata_list[0].state + assert state is not None + return state.current_step + + def clone( + self, seq_group_metadata_list: list[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] + ) -> "ExecuteModelRequest": + """Clone the request with a new sequence group metadata list.""" + return ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=self.blocks_to_swap_in.copy(), + blocks_to_swap_out=self.blocks_to_swap_out.copy(), + blocks_to_copy=self.blocks_to_copy.copy(), + virtual_engine=self.virtual_engine, + num_lookahead_slots=self.num_lookahead_slots, + running_queue_size=self.running_queue_size, + previous_hidden_states=copy.copy(self.previous_hidden_states), + num_steps=self.num_steps, + finished_requests_ids=self.finished_requests_ids, + last_sampled_token_ids=self.last_sampled_token_ids.clone() + if self.last_sampled_token_ids is not None else None, + async_callback=self.async_callback) + + +@dataclass +class SequenceGroupBase: + group_id: str # the original request id before splitting + + assembled_seq_group: Optional[SequenceGroup] = None + + # seq id to a unique index inside this group + seq_id_to_index: dict[str, int] = field(default_factory=dict) + + # seq ids to be finished + to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict) + + # seq id to finished sequences + finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict) + + streaming: bool = False + + output_produced: bool = False + + @staticmethod + def add_request(request_id: str, engine, params, *args, **kwargs): + """When we are ready to add a request with request_id and params + into the engine, we can split the request into multiple requests. + """ + raise NotImplementedError + + def finish_seq(self, seq: SequenceGroup): + """The sequence `seq` finishes, we should record the information. + """ + del self.to_be_finished[seq.request_id] + self.finished_reqs[seq.request_id] = seq + + def maybe_assemble_group( + self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: + """Assemble the sequence group, for producing the final + output, or adding request in the engine again. 
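+
+        Implementations are expected to return the assembled ``SequenceGroup``
+        when an output should be produced for ``seq_group`` and ``None``
+        otherwise (see ``ParallelSampleSequenceGroup.maybe_assemble_group``).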
+ """ + raise NotImplementedError + + +class ParallelSampleSequenceGroup(SequenceGroupBase): + + @staticmethod + def add_request(request_id: str, engine, params, **kwargs): + original_params = params + group = ParallelSampleSequenceGroup(request_id) + seqs = [] + for i in range(original_params.n): + request_id_i = f"{request_id}_parallel_sample_{i}" + group.seq_id_to_index[request_id_i] = i + params = copy.deepcopy(original_params) + params.n = 1 + if params.seed is not None: + params.seed += i + seq_group = engine._add_processed_request( + request_id_i, + params=params, + **kwargs, + ) # type: ignore + assert seq_group is not None + engine.seq_id_to_seq_group[request_id_i] = group + group.to_be_finished[request_id_i] = seq_group + seqs.append(seq_group.seqs[0]) + + # for parallel sampling, the `assembled_seq_group` is always + # available, since we have all the sequences ready, and they + # will not change. + group.assembled_seq_group = SequenceGroup( + request_id=request_id, + seqs=seqs, + arrival_time=seq_group.arrival_time, + sampling_params=original_params, + lora_request=seq_group.lora_request, + pooling_params=seq_group.pooling_params, + pooled_data=seq_group.pooled_data, + encoder_seq=seq_group.encoder_seq, + trace_headers=seq_group.trace_headers, + prompt_adapter_request=seq_group.prompt_adapter_request, + priority=seq_group.priority, + ) + + group.streaming = params.output_kind == RequestOutputKind.DELTA + group.output_produced = False + + def maybe_assemble_group( + self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: + + # in the streaming mode, we will return the assembled sequence + # for the first remaining sequence, and then return None for the + # rest of sequences + if self.streaming: + first_remaining_id = next(iter(self.to_be_finished)) + if seq_group.request_id == first_remaining_id: + return self.assembled_seq_group + return None + + # in the non-streaming mode, we will return the assembled sequence + # when the last sequences finishes, and then return None for the + # rest of the time + if (len(self.to_be_finished) == 1 + and seq_group.request_id in self.to_be_finished + and seq_group.is_finished()): + assert self.assembled_seq_group is not None + params = self.assembled_seq_group.sampling_params + assert isinstance(params, SamplingParams) + if not self.output_produced: + self.output_produced = True + if params._real_n is not None: + # Get the top-n sequences. 
+ n = params._real_n or params.n + seqs = self.assembled_seq_group.seqs + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + self.assembled_seq_group.seqs = top_n_seqs + return self.assembled_seq_group + if self.output_produced: + return None + return None diff --git a/vllm/spec_decode/__init__.py b/vllm/spec_decode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py new file mode 100644 index 0000000..4ef1adc --- /dev/null +++ b/vllm/spec_decode/batch_expansion.py @@ -0,0 +1,538 @@ +# SPDX-License-Identifier: Apache-2.0 + +from array import array +from itertools import chain, count +from typing import Iterator, List, Optional, Tuple + +import torch + +from vllm import SamplingParams +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len + +SeqId = int +TargetSeqId = int +TokenId = int + +DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams() + + +class BatchExpansionTop1Scorer(SpeculativeScorer): + """Implements a speculative scorer that uses batch expansion to get + probabilities of speculative tokens according to the scoring model. + + Batch expansion converts a list of sequences and multiple query positions + to a new batch of sequences, each with a single query position. This allows + for MQA-like scoring in speculative decoding without requiring an MQA + kernel. + + It is strictly less efficient than MQA scoring. + + It only supports scoring the top1 proposal tokens of the proposer, instead + of topk/tree. + """ + + @nvtx_range("BatchExpansionTop1Scorer.score_proposals") + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + """Score the proposed tokens via the scorer model. + + This converts each input sequence to a set of k+1 target sequences. The + target sequences have the unique continuations to be scored and a + unique sequence ID that is different from all input sequence ids. + + If a speculative sequence length would exceed the max model length, then + no speculation is produced for that sequence. + + Args: + execute_model_req: The execution request. + proposals: The speculative proposals to score. + Returns: + SpeculativeScores: The scores of each speculative token, along with + which sequences were ignored during scoring. + """ + + # TODO(cade) perform this on GPU to remove blocking call. + proposal_lens_list = proposals.proposal_lens.tolist() + proposal_token_ids_list = proposals.proposal_token_ids.tolist() + + # Filter the list to ignore invalid proposals. 
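+        # (Proposals containing VLLM_INVALID_TOKEN_ID, e.g. because
+        # speculating would exceed the max model length, are excluded from
+        # the expansion here; see the score_proposals docstring above.)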
+ proposal_token_ids_list_without_skips = [ + proposals for proposals in proposal_token_ids_list + if VLLM_INVALID_TOKEN_ID not in proposals + ] + + (spec_indices, non_spec_indices, target_seq_group_metadata_list, + num_scoring_tokens) = self._expand_batch( + seq_group_metadata_list=execute_model_req.seq_group_metadata_list, + proposal_token_ids_list=proposal_token_ids_list_without_skips, + proposal_lens_list=proposal_lens_list, + ) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + if get_pp_group().is_last_rank: + assert len( + target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] + # Store hidden states from target model execution, BxD. + sampled_token_probs = target_sampler_output.sampled_token_probs + logprobs = target_sampler_output.logprobs + sampled_token_ids = target_sampler_output.sampled_token_ids + hidden_states = target_sampler_output.hidden_states + prefill_hidden_states = target_sampler_output.prefill_hidden_states + tensors = { + "sampled_token_probs": sampled_token_probs, + "logprobs": logprobs, + "sampled_token_ids": sampled_token_ids, + "hidden_states": hidden_states, + "prefill_hidden_states": prefill_hidden_states + } + get_pp_group().broadcast_tensor_dict( + tensors, src=get_pp_group().world_size - 1) + else: + tensors = get_pp_group().broadcast_tensor_dict( + src=get_pp_group().world_size - 1) + sampled_token_probs = tensors["sampled_token_probs"] + logprobs = tensors["logprobs"] + sampled_token_ids = tensors["sampled_token_ids"] + hidden_states = tensors["hidden_states"] + prefill_hidden_states = tensors["prefill_hidden_states"] + target_sampler_output = SamplerOutput( + outputs=None, + sampled_token_probs=sampled_token_probs, + logprobs=logprobs, + sampled_token_ids=sampled_token_ids, + hidden_states=hidden_states, + prefill_hidden_states=prefill_hidden_states) + + if not non_spec_indices: + # All sequence groups in batch have spec decoding enabled + return self._contract_batch_all_spec( + target_sampler_output=target_sampler_output, + proposals=proposals, + ) + else: + # Batch has a mix of spec decode enabled and disabled seq groups + return self._contract_batch( + execute_model_req.seq_group_metadata_list, + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) + + def _expand_batch( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids_list: List[List[TokenId]], + proposal_lens_list: List[int], + ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: + """Given the input sequences and potentially multiple corresponding + proposal tokens, create a new batch where each sequence has a single + query token. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. 
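+        # Worked example (illustrative sizes): for a batch of 3 sequences in
+        # which 2 are speculated with k == 3 and 1 is not, spec_seqs holds 2
+        # entries (each expanded below into k + 1 = 4 single-query target
+        # sequences) and non_spec_seqs holds 1, so num_scoring_tokens == 8 and
+        # the target batch contains 9 sequence groups.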
+ (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \ + split_batch_by_proposal_len( + seq_group_metadata_list, proposal_lens_list) + + spec_expanded_seqs = self._create_scoring_model_input( + seq_group_metadata_list=spec_seqs, + proposal_token_ids=proposal_token_ids_list, + # NOTE: We determine the seq ids in the expanded batch using the + # full seq_group_metadata_list, instead of only spec_seqs. + target_seq_ids_iter=self._create_target_seq_id_iterator( + seq_ids=get_all_seq_ids(seq_group_metadata_list)), + ) + + num_scoring_tokens = len(spec_expanded_seqs) + # Batch speculative and non-speculative (e.g. chunked prefill) requests + # but make sure order is prefill|decode due to backend requirement. + target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs + + return (spec_indices, non_spec_indices, target_seq_group_metadata_list, + num_scoring_tokens) + + def _contract_non_speculative( + self, scores: SpeculativeScores, + seq_group_metadata_list: List[SequenceGroupMetadata], + non_spec_indices: List[int], non_spec_outputs: SpeculativeScores, + has_prompt_log: bool) -> SpeculativeScores: + """ + Augment input `scores` with non-speculative requests outputs. + This includes decode requests with speculation turned off, as well + as prefill requests when `enable_chunked_prefill` is set. + For the latter, prefills are further separated into terminal and + non-terminal chunks (from which no token is sampled). + """ + if not non_spec_indices: + return scores + + if has_prompt_log: + # When prompt_logprobs is enabled, prefills yield output token + # (and respective prob) in the last entry (prompt|out): + # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..]. + # With chunked prefill, non-terminal chunks have -1 on each + # position: they're still picked, but they're discarded later. + seq_meta = seq_group_metadata_list + nospec_sizes = torch.tensor([ + seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1 + for i in non_spec_indices + ]) + nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1) + else: + # In this case only sampled tokens are returned, select all. + nospec_sampled_token_idxs = list( + range(len(non_spec_outputs.token_ids))) + + scores.token_ids[non_spec_indices, :1] = \ + non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1) + scores.probs[non_spec_indices, :1, :] = \ + non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1) + scores.logprobs[non_spec_indices, :1, :] = \ + non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1) + if scores.hidden_states is not None: + assert non_spec_outputs.hidden_states is not None + scores.hidden_states[non_spec_indices, :1, :] = \ + non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1) + return scores + + def _contract_batch( + self, + contracted_seq_group_metadata_list: List[SequenceGroupMetadata], + target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], + k: int) -> SpeculativeScores: + """Contract the expanded batch back into its original size. + This maps the scores of speculative tokens back to their original + sequences. + + contracted_bs is the original batch size, and the batch size that the + target_sampler_output will be contracted to. 
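+
+        Illustrative example (assumed sizes): with 2 speculative sequences
+        (k == 2) and 1 non-speculative sequence, the expanded batch holds
+        2 * (2 + 1) + 1 = 7 single-query sequences; their scores are mapped
+        back to tensors of shape [contracted_bs == 3, k + 1], and the
+        non-speculative row is filled only at position 0.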
+ """ + contracted_bs = len(contracted_seq_group_metadata_list) + (target_token_ids, target_probs, target_logprobs, target_hidden_states, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( + target_sampler_output, num_scoring_tokens) + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. + expanded_batch_size, k = proposals.proposal_token_ids.shape + + # The number of tokens in the expanded batch used for speculation is + # equal to the total expanded batch size minus the number of samples for + # non-speculative sequences, prefill chunks with no out tokens included + non_spec_expanded_bs = len(non_spec_indices) + spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs + + target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) + target_probs = target_probs.reshape(*target_token_ids.shape, + self._vocab_size) + target_logprobs = target_logprobs.reshape(target_probs.shape) + + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + *target_token_ids.shape, target_hidden_states.shape[-1]) + + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) + + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + + has_prompt_log = any((sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + for sg in contracted_seq_group_metadata_list) + # When prompt logprobs is enabled, lens of returned tensors go from + # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. + # We adjust stride accordingly to get the generated tokens and + # their probs, but pass on prompt_logprobs as is. + prompt_logprobs = None + if (not self._scorer_worker.model_runner.disable_logprobs\ + and has_prompt_log): + prompt_logprobs = [ + o.prompt_logprobs for o in target_sampler_output.outputs + ] + elif not has_prompt_log: + # When prompt logprobs are not to be returned, + # we can ignore non-terminal chunks (no out token). + non_spec_indices = [ + idx for idx in non_spec_indices + if contracted_seq_group_metadata_list[idx].do_sample + ] + + # "Contract" speculative. + if spec_indices: + all_tokens[spec_indices] = target_token_ids + all_probs[spec_indices] = target_probs + all_logprobs[spec_indices] = target_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + spec_scores = SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=all_hidden_states, + prompt_logprobs=prompt_logprobs) + + non_spec_outputs = SpeculativeScores( + probs=non_spec_target_probs, + token_ids=non_spec_target_token_ids, + logprobs=non_spec_target_logprobs, + hidden_states=non_spec_target_hidden_states) + # Contract remaining nonspec entries based on non_spec_indices, if any. 
+ return self._contract_non_speculative( + spec_scores, contracted_seq_group_metadata_list, non_spec_indices, + non_spec_outputs, has_prompt_log) + + def _contract_batch_all_spec( + self, + target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + """Contract the expanded batch back into its original size. + This maps the scores of speculative tokens back to their original + sequences. + + It assumes all sequences in the batch were previously expanded. + """ + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. + contracted_bs, k = proposals.proposal_token_ids.shape + + # Reshape tensors to original batch size + target_token_ids = target_sampler_output.sampled_token_ids.reshape( + contracted_bs, k + 1) + target_probs = target_sampler_output.sampled_token_probs.reshape( + *target_token_ids.shape, self._vocab_size) + target_logprobs = target_sampler_output.logprobs.reshape( + target_probs.shape) + target_hidden_states = target_sampler_output.hidden_states + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + *target_token_ids.shape, target_hidden_states.shape[-1]) + + return SpeculativeScores(probs=target_probs, + token_ids=target_token_ids, + logprobs=target_logprobs, + hidden_states=target_hidden_states, + prompt_logprobs=None) + + def _create_scoring_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given the original input sequences and proposed tokens from the draft + model, create a list of target sequences that can be used for scoring. + + target_seq_ids_iter provides sequence ids for the expanded batch, + fulfilling the requirement that no seq id in the expanded batch is equal + to the seq id in the original batch. + """ + + if not seq_group_metadata_list: + return [] + + target_seq_group_metadata = list( + chain.from_iterable( + self._create_target_seq_group_metadata( + seq_group_metadata, + proposal_token_ids, + i, + target_seq_ids_iter, + ) for i, seq_group_metadata in enumerate( + seq_group_metadata_list))) + + return target_seq_group_metadata + + def _create_target_seq_group_metadata( + self, + input_seq_group_metadata: SequenceGroupMetadata, + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + batch_index: int, + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given an input sequence group metadata and a list of draft tokens, + create a list of target SequenceGroupMetadata, one for each + token id that needs to be scored. + + Naive speculative decoding requires K target model scores, one for each + draft model token. However one can add a bonus token such that if each + token is accepted, then a final token may be sampled from the model. + This function creates K+1 target SequenceGroupMetadata to take + advantage of the bonus token. 
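+
+        Illustrative example: for draft tokens [5, 9, 3] (K == 3), the scored
+        continuations are [], [5], [5, 9] and [5, 9, 3], i.e. K + 1 target
+        sequences; the extra one lets a bonus token be sampled when every
+        draft token is accepted.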
+ """ + assert len(input_seq_group_metadata.seq_data) == 1, ( + "Beam search " + "not supported in speculative decoding") + input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys())) + + token_ids_to_score = self._get_token_ids_to_score( + proposal_token_ids[batch_index]) + + sampling_params = input_seq_group_metadata.sampling_params + target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for i, token_ids in enumerate(token_ids_to_score): + target_seq_group_metadata_list.append( + self._create_single_target_seq_group_metadata( + input_seq_group_metadata, + input_seq_id, + next(target_seq_ids_iter), + token_ids, + sampling_params=sampling_params, + )) + + return target_seq_group_metadata_list + + @staticmethod + def _create_single_target_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, + seq_id: SeqId, + target_seq_id: TargetSeqId, + token_ids: List[TokenId], + sampling_params: SamplingParams, + ) -> SequenceGroupMetadata: + """Create a single target SequenceGroupMetadata. + + Args: + seq_group_metadata: The metadata for the input sequence. + seq_id: The input sequence ID. + target_seq_id: The corresponding target sequence ID. + token_ids: The list of token ids that are to be appended to the + input sequence. + """ + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_token_ids = seq_data.prompt_token_ids_array + new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + mrope_position_delta = seq_data.mrope_position_delta + + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids, + _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids), + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. + for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + data.mrope_position_delta = mrope_position_delta + + return SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + token_chunk_size=1, + ) + + @staticmethod + def _split_scoring_output( + sampler_output: SamplerOutput, num_scoring_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: + """Split the target model output into speculative and non-speculative + output. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + # + # First samples are non-speculative, latter samples are from speculative + # scoring (prefill|decode order). 
+ split_sizes = (sampler_output.sampled_token_ids.numel() - + num_scoring_tokens, num_scoring_tokens) + (non_spec_probs, + spec_probs) = sampler_output.sampled_token_probs.split(split_sizes) + (non_spec_sampled_tokens, spec_sampled_tokens + ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + (non_spec_logprobs, + spec_logprobs) = sampler_output.logprobs.split(split_sizes) + + if sampler_output.hidden_states is not None: + (non_spec_hidden_states, spec_hidden_states + ) = sampler_output.hidden_states.split(split_sizes) + else: + non_spec_hidden_states, spec_hidden_states = None, None + + return (spec_sampled_tokens, spec_probs, spec_logprobs, + spec_hidden_states, non_spec_sampled_tokens, non_spec_probs, + non_spec_logprobs, non_spec_hidden_states) + + @staticmethod + def _create_target_seq_id_iterator( + seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: + """Create an iterator for creating target sequence ids. + Target sequence ids are distinct from sequence ids because we create a + distinct target sequence id for each proposal token to be scored. + + This implementation increments a counter starting at 1 + max of all + provided input sequence ids. + """ + return count(start=max(seq_ids) + 1) + + @staticmethod + def _get_token_ids_to_score( + full_spec_token_ids: List[TokenId] # shape: [k] + ) -> List[List[TokenId]]: + """Given an int tensor of proposal token ids, return a list of + token ids that should be scored. + + Returns k+1 output lists. The additional one is used for generating the + bonus token. + + Example: + Input: [0, 1, 2, 3] (k=4) + Output: (k+1 lists) + [] + [0] + [0, 1] + [0, 1, 2] + [0, 1, 2, 3] + """ + empty_token_ids: List[TokenId] = [] + + token_ids_to_score = [empty_token_ids] + token_ids_to_score.extend(full_spec_token_ids[:i + 1] + for i in range(len(full_spec_token_ids))) + return token_ids_to_score diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py new file mode 100644 index 0000000..fa69278 --- /dev/null +++ b/vllm/spec_decode/draft_model_runner.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +import torch + +from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.sampler import SamplerOutput + +try: + try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata + except (ModuleNotFoundError, ImportError): + # vllm_flash_attn is not installed, try the ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) +except (ModuleNotFoundError, ImportError) as err: + raise RuntimeError( + "Draft model speculative decoding currently only supports " + "CUDA and ROCm flash attention backend.") from err + +from vllm.logger import init_logger +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerWrapperBase) + +logger = init_logger(__name__) + +# A flag to enable debug prints for the updated input tensors +# before each step. +debug_advance_input = False +# A flag to allow GPU advance step for draft model runner. +# Set to False for debugging. +allow_gpu_advance_step = True + + +class TP1PP1DraftModelRunner(ModelRunnerWrapperBase): + """Specialized model runner for speculative decoding draft model. 
+ Since the draft model always execute k forward passes consecutively to + generate k speculative tokens in a single speculative decoding step, + we could get rid of most CPU-GPU synchronization and data transfer + overheads by keeping model input and output tensors on GPU all the time. + + TODOs: + 1. Currently supports only flash-attn, add support for other attn_backends. + 2. Support TP > 1 (this requires some designs because we do not expect + any broadcasting inside execute_model). + """ + + def __init__(self, model_runner: ModelRunnerBase): + super().__init__(model_runner) + + self.indices_of_seq_with_bonus_tokens = None + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + + def _gpu_advance_step(self, model_input: ModelRunnerInputBase, + last_output: SamplerOutput) -> ModelRunnerInputBase: + # Currently, we expect "decode mode" only + assert not model_input.is_prompt + + # Get num_seqs + num_seqs = len(model_input.seq_lens) + num_queries = len(model_input.query_lens) + + # Get output tokens GPU tensor + sampled_token_ids = last_output.sampled_token_ids + assert sampled_token_ids is not None + + # Update attn_metadata + attn_metadata = model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + + attn_metadata.advance_step(model_input, sampled_token_ids, + self.block_size, num_seqs, num_queries) + + # Update sampling_metadata + sampling_metadata = model_input.sampling_metadata + self._update_sampling_metadata(sampling_metadata, num_seqs, + num_queries) + + # Create new input + new_model_input = self._model_input_cls( + input_tokens=model_input.input_tokens, + input_positions=model_input.input_positions, + attn_metadata=attn_metadata, + seq_lens=attn_metadata.seq_lens, + query_lens=model_input.query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + sampling_metadata=model_input.sampling_metadata, + is_prompt=False, + ) + + # Ensure we skip CPU samples + assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True + # We can reuse sampling tensors since every decode iteration is the same + new_model_input.sampling_metadata.reuse_sampling_tensors = True + + if debug_advance_input: + logger.debug("NEW INPUT: ") + logger.debug(" input_tokens = %s", new_model_input.input_tokens) + logger.debug(" input_positions = %s", + new_model_input.input_positions) + logger.debug(" seq_lens = %d", new_model_input.seq_lens) + logger.debug(" query_lens = %d", new_model_input.query_lens) + logger.debug(" attn_metadata:") + logger.debug(" seq_lens_tensor: %s", + attn_metadata.seq_lens_tensor) + logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping) + logger.debug(" block_tables: %s", attn_metadata.block_tables) + + return new_model_input + + def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): + """Determines if draft_model_runner GPU multi-step can be used. 
+ Currently required conditions are: + 1. Only decodes + 2. Only flash-attn + 3. No LORA + 4. No prompt_adapter_config + """ + if not allow_gpu_advance_step: + return False + + # We allow multi-step GPU only in decode mode + for seq_group in execute_model_req.seq_group_metadata_list: + if seq_group.is_prompt: + return False + + # TODO: Add support for other attn backends + if self.attn_backend.get_name() not in ("FLASH_ATTN", ): + return False + + # TODO: Add support for LORA + if self.lora_config: + return False + + # TODO: Add soft-tuning prompt adapter support + return not self.prompt_adapter_config + + def set_indices_of_seq_with_bonus_tokens(self, + indices_of_seq_with_bonus_tokens): + self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelRunnerInputBase, + kv_caches: List[torch.Tensor], + previous_hidden_states: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + **kwargs, + ) -> Optional[List[SamplerOutput]]: + """Executes num_steps forward passes with advacement of input tensors + on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. + + Optimizations used: + 1. Input tensors are updated on the GPU directly + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + them since we do batch expansion later that uses GPU outputs) + 3. Reuses sampling tensors (since we run only decodes and they have + a repeating sampling logic) + """ + + # When num_steps == 1, we execute the fallback here for the GPU + # advance_step, which runs prepare_inputs on CPU and for each spec + # iteration invokes this function only once + # (Look at multi-step-worker code) + is_fallback = num_steps == 1 + if not is_fallback: + # Since we do not broadcast data inside execute_model anymore, + # we need to figure out the best way to support TP > 1 in this + # case, because we will at least need to broadcast the sampled + # tokens to all workers. + if not self.is_driver_worker: + raise ValueError("TP1DraftModelRunner only supports TP=1.") + + # Sanity + if self.lora_config is not None: + raise ValueError("TP1DraftModelRunner has no support for LORA") + if self.prompt_adapter_config is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "prompt_adapter_config") + if model_input.multi_modal_kwargs: + raise ValueError( + "TP1DraftModelRunner has no support for multi_modal_kwargs" + ) + else: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + self.attn_state.begin_forward(model_input) + + # Detect exec mode + assert model_input.attn_metadata is not None + use_cuda_graph = False + if model_input.attn_metadata.num_prefills > 0: + # In this case, execute_model(..) was called directly + if num_steps > 1: + raise ValueError( + "execute_model(..) of draft_model_runner can be called " + "directly only with a single-step prefill") + else: + # We can skip CPU samples for spec token generation. + # (We do allow CPU samples for num_steps == 1 to support the + # fallback case, where supports_gpu_multi_step(..) 
does not pass) + model_input.sampling_metadata.skip_sampler_cpu_output = ( + not is_fallback) + + # Attn attr defines if we use cuda graphs + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + + # Get model + if use_cuda_graph: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = (self.graph_runners[model_input.virtual_engine] + [graph_batch_size]) + + if previous_hidden_states is not None: + hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + hidden_states = None + else: + model_executable = self.model + hidden_states = previous_hidden_states + + outputs: List[SamplerOutput] = [] + for step in range(num_steps): + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + + model_execute_kwargs = {"previous_hidden_states": hidden_states} \ + if previous_hidden_states is not None else {} + + compute_logits_kwargs = {} + # Run model + if hasattr(self.model.config, "num_nextn_predict_layers"): + # for DeepSeek MTP only to use the corresponding layer for + # each step + spec_step_idx = kwargs.get("spec_step_idx", step) + model_execute_kwargs["spec_step_idx"] = spec_step_idx + compute_logits_kwargs["spec_step_idx"] = spec_step_idx + with set_forward_context(model_input.attn_metadata, + self.vllm_config): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **model_execute_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata, + **compute_logits_kwargs) + if not self.is_driver_worker: + return [] + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + outputs.append(output) + + if self.return_hidden_states and is_fallback: + if use_cuda_graph: + indices = model_input.sampling_metadata\ + .selected_token_indices + output.hidden_states = hidden_states[:len(indices)] + else: + output.hidden_states = hidden_states + + if model_input.attn_metadata.num_prefills == 0 \ + and self.indices_of_seq_with_bonus_tokens is not None: + assert output.sampled_token_ids is not None + # output.sampled_token_ids should be of shape (num_seqs, 1) + nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape + assert num_tokens_per_seq == 1 + count = 0 + for i in range(nums_seqs): + bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[ + count] + if i != bonus_seq_idx: + # The following might cause a cpu->gpu sync + # However, the performance impact is negligible as we + # benchmarked on H100. 
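The CUDA-graph branch above pads `previous_hidden_states` up to the captured graph batch size. A minimal sketch of that padding step with assumed shapes follows; the bonus-token handling of the patch continues below.

```python
import torch

# Assumed sizes: 5 live sequences, hidden dim 16, graph captured for batch 8.
graph_batch_size = 8
previous_hidden_states = torch.randn(5, 16)

padding = torch.empty(
    graph_batch_size - previous_hidden_states.shape[0],
    *previous_hidden_states.shape[1:],
    dtype=previous_hidden_states.dtype,
    device=previous_hidden_states.device,
)
hidden_states = torch.cat([previous_hidden_states, padding])
print(hidden_states.shape)  # torch.Size([8, 16])
```

The padding rows are left uninitialized (`torch.empty`), mirroring the call above: they only fill padded batch slots.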
+ output.sampled_token_ids[ + i, :] = model_input.input_tokens[bonus_seq_idx] + else: + count += 1 + + # Prepare inputs for the next step + if step != num_steps - 1: + model_input = self._gpu_advance_step(model_input, outputs[-1]) + + return outputs diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py new file mode 100644 index 0000000..dd085ad --- /dev/null +++ b/vllm/spec_decode/interfaces.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional, Set, Union + +import torch + +from vllm.sequence import ExecuteModelRequest, PromptLogprobs +from vllm.worker.worker_base import WorkerBase + + +@dataclass +class SpeculativeProposals: + """Datastructure used to represent proposal tokens from some proposer. It + also tracks how many speculative tokens each sequence has. + """ + + # Speculative proposal tokens. + proposal_token_ids: torch.Tensor + + # Probabilities of the proposal tokens according to the proposer. + proposal_probs: torch.Tensor + + # The valid length of each proposal; can be zero. + proposal_lens: torch.Tensor + + # A flag to mark that there's no available proposals + no_proposals: bool = False + + def __repr__(self): + return (f"SpeculativeProposals(" + f"proposal_token_ids={self.proposal_token_ids}, " + f"proposal_probs={self.proposal_probs.shape}, " + f"proposal_lens={self.proposal_lens})") + + +@dataclass +class SpeculativeScores: + """Datastructure used to represent the scores of speculative tokens + according to the scoring model. + """ + + # Probabilities of the speculative tokens according to the scoring model. + probs: torch.Tensor + + # Log-probabilities of the speculative tokens according to the scoring + # model. These values can be used to generate Logprob objects that are + # returned to the user. + logprobs: torch.Tensor + + # Token ids sampled from the scoring model. Used for speculative bonus + # tokens and also non-speculative normal decoding. + token_ids: torch.Tensor + + # Optional last hidden states from the scoring model. + hidden_states: Optional[torch.Tensor] = None + + # Scoring model may also return logprobs for prompt tokens + # for each request, when chunked prefill is enabled. + prompt_logprobs: Optional[List[PromptLogprobs]] = None + + def __repr__(self): + return (f"SpeculativeScores(" + f"probs={self.probs.shape}, " + f"token_ids={self.token_ids.shape})") + + +class SpeculativeProposer(ABC): + + @abstractmethod + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + # If set, this contains all sequence IDs that were assigned + # bonus tokens in their last forward pass. 
+ seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + raise NotImplementedError + + +class SpeculativeScorer(ABC): + + def __init__(self, scorer_worker: WorkerBase, + device: Union[torch.device, str], vocab_size: int): + self._scorer_worker = scorer_worker + if isinstance(device, torch.device): + device = device.type + self._device = device + self._vocab_size = vocab_size + + @abstractmethod + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + raise NotImplementedError diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py new file mode 100644 index 0000000..0b62a98 --- /dev/null +++ b/vllm/spec_decode/medusa_worker.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 + +import weakref +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker_base import DelegateWorkerBase + + +class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase): + """Worker for Medusa. + """ + + def __init__(self, *args, **kwargs): + DelegateWorkerBase.__init__(self, *args, **kwargs) + # Lazy initialization list. + self._proposer: Top1Proposer + + def init_device(self): + self.worker.init_device() + + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + self.device, + self.vocab_size, + max_proposal_len=self.max_model_len, + ) + + def set_include_gpu_probs_tensor(self): + pass + + def set_should_modify_greedy_probs_inplace(self): + pass + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. + Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For medusa worker, this indicator shall be False. + """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + seq_lens, query_lens = self._prepare_input_tensors( + seq_group_metadata_list) + + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory, generators) + + model_outputs = self.model_runner.model.generate_proposals( + previous_hidden_states=execute_model_req.previous_hidden_states. 
+ hidden_states, + sampling_metadata=sampling_metadata) + + return model_outputs, False + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[List[int], List[int]]: + if not seq_group_metadata_list: + return [], [] + + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + seq_lens.append(seq_len) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + query_lens.append(1) + + return seq_lens, query_lens + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """MedusaWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MedusaWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MedusaWorker does not support beam search.") diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py new file mode 100644 index 0000000..9479282 --- /dev/null +++ b/vllm/spec_decode/metrics.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from typing import Callable, Optional, Union + +import msgspec +import torch + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) +from vllm.utils import is_pin_memory_available + + +class SpecDecodeWorkerMetrics( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """Dataclass holding metrics emitted from the spec decode worker. + """ + + # The empirical acceptance rate of the proposal method on a per-token basis. + # This is useful for evaluating how well the proposal method aligns with the + # scoring method. + draft_acceptance_rate: float + + # The empirical efficiency, measured as the number of tokens emitted by the + # system divided by the number of tokens that could be emitted by the system + # if the proposal method were perfect. + system_efficiency: float + + # The number of speculative tokens produced by the proposal method. + draft_tokens: int + + # The number of tokens emitted by the entire system. + emitted_tokens: int + + # The number of tokens accepted by the scoring model and verification + # routine, e.g. Llama2-70B and lossless rejection sampling. + # + # NOTE: Any token accepted by the verification routine is considered + # accepted (regardless of if the speculative prefix is also accepted). The + # user will usually see less accepted tokens. 
This metric is helpful when + # evaluating alignment of the proposal method with the scoring model. + accepted_tokens: int + + # The number of speculative tokens per sequence. + num_spec_tokens: int + + +Timer = Callable[[], float] + + +class AsyncMetricsCollector: + """Class which copies rejection/typical-acceptance sampler metrics + from the device to CPU on a non-default Torch stream. + """ + + def __init__(self, + spec_decode_sampler: SpecDecodeBaseSampler, + timer: Optional[Timer] = None, + collect_interval_s: float = 5.0): + self.spec_decode_sampler = spec_decode_sampler + self._timer = time.time if timer is None else timer + + self._rank: Optional[int] = None + + # We don't have a device set yet. + self._copy_stream: Optional[torch.cuda.Stream] = None + + self._in_flight_copy: Optional[torch.cuda.Event] = None + + pin_memory = False + self._aggregate_num_accepted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_emitted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_draft_tokens = 0 + + self._rejsample_metrics_collect_interval_s = collect_interval_s + self._last_metrics_collect_time = self._timer() + + def init_gpu_tensors(self, rank: int) -> None: + self._rank = rank + self._copy_stream = torch.cuda.Stream() + + def init_tensors(self, + rank: int, + device_type: Union[torch.device, str] = 'cuda') -> None: + self._rank = rank + if isinstance(device_type, torch.device): + device_type = device_type.type + if device_type == 'cuda': + self._copy_stream = torch.cuda.Stream() + + def maybe_collect_rejsample_metrics( + self, k: int) -> Optional[SpecDecodeWorkerMetrics]: + # currently using cuda.Event, skip for any non_cuda_alike platform + from vllm.platforms import current_platform + if not current_platform.is_cuda_alike(): + return None + + # If a copy was initiated in the previous call, collect and return. + if self._in_flight_copy is not None: + ready_event = self._in_flight_copy + self._in_flight_copy = None + return self._collect_rejsample_metrics(k, ready_event) + + # Otherwise, check if we should start a new copy. + if self._should_collect_rejsample_metrics(self._timer()): + assert self._in_flight_copy is None + self._in_flight_copy = self._copy_rejsample_metrics_async() + + return None + + def _should_collect_rejsample_metrics(self, now: float) -> bool: + """Return whether or not this iteration should print sampling + metrics. + """ + if self._rank != 0: + return False + + return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501 + + def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: + """Copy rejection/typical-acceptance sampling metrics + (number of accepted tokens, etc) to CPU asynchronously. + + Returns a CUDA event recording when the copy is complete. + """ + assert self._copy_stream is not None + self._copy_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self._copy_stream): + self._aggregate_num_accepted_tokens.copy_( + self.spec_decode_sampler.num_accepted_tokens, + non_blocking=True) + self._aggregate_num_emitted_tokens.copy_( + self.spec_decode_sampler.num_emitted_tokens, non_blocking=True) + # Number of draft tokens is calculated on CPU, so no copy is + # required. 
+ self._aggregate_num_draft_tokens = ( + self.spec_decode_sampler.num_draft_tokens) + + aggregate_metrics_ready = torch.cuda.Event() + aggregate_metrics_ready.record(self._copy_stream) + + return aggregate_metrics_ready + + def _collect_rejsample_metrics( + self, k: int, + ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics: + """Create metrics object from statistics copied asynchronously. + + Args: + k: int. The number of speculative tokens; used to determine system + efficiency. + ready_event: torch.cuda.Event. The CUDA event recording when the + async GPU->CPU copy is complete. + """ + + ready_event.synchronize() + + # update time of last collection + self._last_metrics_collect_time = self._timer() + + accepted_tokens = self._aggregate_num_accepted_tokens.item() + emitted_tokens = self._aggregate_num_emitted_tokens.item() + draft_tokens = self._aggregate_num_draft_tokens + + max_num_emitted_tokens = self.get_max_num_emitted_tokens( + draft_tokens, k) + + if draft_tokens > 0: + draft_acceptance_rate = accepted_tokens / draft_tokens + else: + draft_acceptance_rate = float("nan") + + if max_num_emitted_tokens > 0: + system_efficiency = emitted_tokens / max_num_emitted_tokens + else: + system_efficiency = float("nan") + + return SpecDecodeWorkerMetrics( + num_spec_tokens=k, + draft_acceptance_rate=draft_acceptance_rate, + system_efficiency=system_efficiency, + accepted_tokens=accepted_tokens, + draft_tokens=draft_tokens, + emitted_tokens=emitted_tokens, + ) + + @staticmethod + def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int: + """Calculate the number of emitted tokens, assuming all tokens are + accepted. + + This is equal to the number of sequences that have been speculated on, + times (speculation len + 1). The +1 comes from the bonus token. + """ + # Determine the number of sequences that have been speculated on. Since + # the batch size can be variable, we divide by k. + assert draft_tokens % k == 0 + total_num_spec_seqs = draft_tokens // k + + # A single sequence may emit k accepted tokens and one bonus token in + # the best case. + num_emitted_per_seq_if_all_accepted = k + 1 + + # The max num of emitted tokens is the number of speculated sequences + # times the max emitted per seq. + return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py new file mode 100644 index 0000000..bdaf318 --- /dev/null +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase + + +class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): + """Worker for MLPSpeculator models. + + Not currently compatible with LoRA or chunked prefill. + """ + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. 
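A worked example, with made-up numbers, of the metric formulas in `_collect_rejsample_metrics` and `get_max_num_emitted_tokens` above; the MLPSpeculatorWorker code of the patch continues below.

```python
# Assumed window: k = 4 speculative tokens per sequence, 3 sequences speculated on.
k = 4
draft_tokens = 12       # 3 sequences * k proposals
accepted_tokens = 9     # accepted by the rejection/typical-acceptance sampler
emitted_tokens = 11     # tokens actually emitted, including bonus tokens

max_num_emitted_tokens = (draft_tokens // k) * (k + 1)       # 3 * 5 = 15
draft_acceptance_rate = accepted_tokens / draft_tokens        # 0.75
system_efficiency = emitted_tokens / max_num_emitted_tokens   # ~0.73
print(draft_acceptance_rate, system_efficiency)
```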
+ Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For mlp spec worker, this indicator shall be True. + """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + (input_tokens, seq_lens, + query_lens) = self._prepare_input_tensors(seq_group_metadata_list) + + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory, generators) + + model_outputs = self.model_runner.model.generate_proposals( + input_ids=input_tokens, + previous_hidden_states=execute_model_req.previous_hidden_states. + hidden_states, + num_predict_tokens=sample_len, + sampling_metadata=sampling_metadata) + + assert len(model_outputs) == sample_len + + return model_outputs, True + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, List[int], List[int]]: + if not seq_group_metadata_list: + return torch.empty(0, device=self.device), [], [] + + input_tokens: List[int] = [] + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) + input_tokens.extend(tokens) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + input_tokens.append(seq_data.get_last_token_id()) + query_lens.append(1) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + return input_tokens_tensor, seq_lens, query_lens diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py new file mode 100644 index 0000000..40018f7 --- /dev/null +++ b/vllm/spec_decode/mqa_scorer.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) + +SeqId = int +TargetSeqId = int + + +class MQAScorer(SpeculativeScorer): + + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + target_seq_group_metadata_list = [] + target_seq_id_start = max( + get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 + all_proposal_tokens = proposals.proposal_token_ids.tolist() + all_proposal_lengths = proposals.proposal_lens.tolist() + for i, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + if all_proposal_lengths[i] == 0: + # Keep prompt seqs untouched (keep computed_tokens for chunks). 
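The `_prepare_input_tensors` helpers above derive `seq_len` and `query_len` for chunked prompts. A worked example with assumed numbers (the MQA scorer code of the patch continues below):

```python
# Assumed chunked-prefill state for one prompt sequence.
seq_data_len = 100      # full prompt length
context_len = 48        # tokens already computed in earlier chunks
token_chunk_size = 32   # tokens scheduled for this chunk

seq_len = min(seq_data_len, context_len + token_chunk_size)  # 80
query_len = seq_len - context_len                            # 32
print(seq_len, query_len)
```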
+ target_seq_group_metadata_list.append(seq_group_metadata) + continue + + seq_data_dict = seq_group_metadata.seq_data + assert len(seq_data_dict) == 1 + seq_id = next(iter(seq_data_dict.keys())) + + seq_data: SequenceData = seq_data_dict[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + output_token_ids = seq_data.get_output_token_ids() + proposal_token_ids = all_proposal_tokens[ + i][:all_proposal_lengths[i]] + new_output_token_ids = [*output_token_ids, *proposal_token_ids] + + target_seq_id = target_seq_id_start + i + new_seq_data = SequenceData.from_seqs( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ) + new_seq_data.update_num_computed_tokens( + len(prompt_token_ids) + len(output_token_ids) - 1) + + # Ensure that the new decode sequence has at least one token. + assert len(output_token_ids) >= 1 + new_seq_data_dict = {target_seq_id: new_seq_data} + + new_seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + ) + target_seq_group_metadata_list.append(new_seq_group_metadata) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + + if get_pp_group().is_last_rank: + assert len( + target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] + # Store hidden states from target model execution, BxD. + sampled_token_probs = target_sampler_output.sampled_token_probs + logprobs = target_sampler_output.logprobs + sampled_token_ids = target_sampler_output.sampled_token_ids + hidden_states = target_sampler_output.hidden_states + prefill_hidden_states = target_sampler_output.prefill_hidden_states + tensors = { + "sampled_token_probs": sampled_token_probs, + "logprobs": logprobs, + "sampled_token_ids": sampled_token_ids, + "hidden_states": hidden_states, + "prefill_hidden_states": prefill_hidden_states + } + get_pp_group().broadcast_tensor_dict( + tensors, src=get_pp_group().world_size - 1) + else: + tensors = get_pp_group().broadcast_tensor_dict( + src=get_pp_group().world_size - 1) + sampled_token_probs = tensors["sampled_token_probs"] + logprobs = tensors["logprobs"] + sampled_token_ids = tensors["sampled_token_ids"] + hidden_states = tensors["hidden_states"] + prefill_hidden_states = tensors["prefill_hidden_states"] + target_sampler_output = SamplerOutput( + outputs=None, + sampled_token_probs=sampled_token_probs, + logprobs=logprobs, + sampled_token_ids=sampled_token_ids, + hidden_states=hidden_states, + prefill_hidden_states=prefill_hidden_states) + + k = execute_model_req.num_lookahead_slots + bs = len(execute_model_req.seq_group_metadata_list) + target_token_ids = target_sampler_output.sampled_token_ids + target_probs = target_sampler_output.sampled_token_probs + target_logprobs = target_sampler_output.logprobs + prompt_logprobs = None + + # If all requests have the same number of query tokens, we can avoid + # the for loop to build output for better performance. + if min(all_proposal_lengths) == k: + # Regular decodes only. 
+ assert all(not sg.is_prompt + for sg in target_seq_group_metadata_list + if sg.is_prompt) + bs, _ = proposals.proposal_token_ids.shape + all_tokens = target_token_ids.reshape(bs, k + 1) + all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) + all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size) + else: + # We either have decodes with different lens or prefill+decodes. + all_tokens = target_token_ids.new_full(size=(bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, + self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) + target_token_ids = target_token_ids.flatten() + + # When prompt logprobs is enabled, lens of returned tensors go from + # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. + # We adjust stride accordingly to get the generated tokens and + # their probs, but pass on prompt_logprobs as is, since it may be + # that n_prompts >> K. + has_prompt_log = any((sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + for sg in target_seq_group_metadata_list) + # TODO (NickLucche) we should surface `disable_logprobs` as to not + # break abstraction to get its value. + if (not self._scorer_worker.model_runner.disable_logprobs\ + and has_prompt_log): + prompt_logprobs = [ + o.prompt_logprobs for o in target_sampler_output.outputs + ] + + # Split loop into prefill|decode for readability. + start_loc, i = 0, 0 + while i < len(target_seq_group_metadata_list + ) and target_seq_group_metadata_list[i].is_prompt: + seq_meta = target_seq_group_metadata_list[i] + end_loc = start_loc + if has_prompt_log: + end_loc += seq_meta.token_chunk_size + elif seq_meta.do_sample: + end_loc += 1 + + # Skip chunks with no output tokens. + if seq_meta.do_sample: + # Get sampled token (last position in chunk) and its prob. + all_tokens[i, 0] = target_token_ids[end_loc - 1] + all_probs[i, 0] = target_probs[end_loc - 1] + all_logprobs[i, 0] = target_logprobs[end_loc - 1] + + i += 1 + start_loc = end_loc + # Decodes. 
+ while i < len(target_seq_group_metadata_list): + proposed_len, seq_meta = all_proposal_lengths[ + i], target_seq_group_metadata_list[i] + output_len = proposed_len + 1 + end_loc = start_loc + output_len + all_tokens[ + i, :output_len] = target_token_ids[start_loc:end_loc] + all_probs[i, :output_len] = target_probs[start_loc:end_loc] + all_logprobs[ + i, :output_len] = target_logprobs[start_loc:end_loc] + start_loc = end_loc + i += 1 + + hidden_states = None + if target_sampler_output.hidden_states is not None: + hidden_states = target_sampler_output.hidden_states.reshape( + bs, (k + 1), -1) + + return SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=hidden_states, + prompt_logprobs=prompt_logprobs) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py new file mode 100644 index 0000000..70fd352 --- /dev/null +++ b/vllm/spec_decode/multi_step_worker.py @@ -0,0 +1,417 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +import weakref +from typing import Dict, List, Set, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.platforms import current_platform +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, + SequenceGroupMetadata) + +if current_platform.is_cuda_alike(): + from vllm.spec_decode.draft_model_runner import TP1PP1DraftModelRunner + +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker_base import DelegateWorkerBase + + +class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): + """The MultiStepWorker is equivalent to a Worker except that it allows + multiple forward passes in a single call, assuming the scheduler has + allocated enough space to store the additional KV. This reduces overhead + by invoking the scheduler less. + + The MultiStepWorker does not support cache swap operations, or beam search. + Cache swap operations do not require large modifications. On the other hand, + beam search requires memory allocations during sequence forks and thus + requires more thought for MultiStepWorker support. + """ + + def __init__(self, *args, **kwargs): + DelegateWorkerBase.__init__(self, *args, **kwargs) + # Lazy initialization list. + self._proposer: SpeculativeProposer + + def init_device(self) -> None: + self.worker.init_device() + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + self.device, + self.vocab_size, + max_proposal_len=self.max_model_len, + ) + + def set_include_gpu_probs_tensor(self) -> None: + # Need include_gpu_probs_tensor for MultiStepWorker + self.model_runner.model.sampler.include_gpu_probs_tensor = True + + def set_should_modify_greedy_probs_inplace(self) -> None: + self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( + True) + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass sample_len times. Returns the list of + sampler output, one per model forward pass, along with indicator of + whether torch tensor in sampler output need to be transposed in latter + sampler_output_to_torch logic. 
+ + For multi step worker, this indicator shall be True. + """ + self._raise_if_unsupported(execute_model_req) + # Expand the batch for sequences with a bonus token. + # Perform a forward pass on the expanded batch and filter the + # response to retain only the original sequences' responses. + expanded_request, indices_of_seq_with_bonus_tokens =\ + self._expand_execute_model_request( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + # Run model sample_len times. + model_outputs: List[SamplerOutput] = [] + if current_platform.is_cuda_alike() and isinstance( + self.model_runner, TP1PP1DraftModelRunner + ) and self.model_runner.supports_gpu_multi_step(expanded_request): + # Here we run the draft_model_runner with multi-step prepare + # on the GPU directly + expanded_request.num_steps = sample_len + self.model_runner.set_indices_of_seq_with_bonus_tokens( + indices_of_seq_with_bonus_tokens) + model_outputs = self.execute_model( + execute_model_req=expanded_request) + else: + # Here we run multi-step directly, with every step prepared + # on the CPU. + # TODO: Remove this branch once DraftModelRunner supports TP>1 + # and other restrictions that are part of DraftModelRunner's + # supports_gpu_multi_step(..) + if expanded_request.previous_hidden_states is not None: + self.worker.model_runner.return_hidden_states = True + for _ in range(sample_len): + model_output: List[SamplerOutput] = self.worker.execute_model( + execute_model_req=expanded_request) + assert (len(model_output) == 1 + ), "composing multistep workers not supported" + model_output = model_output[0] + self._maybe_update_previous_hidden_states( + model_output, expanded_request) + + self._append_new_tokens( + model_output, expanded_request.seq_group_metadata_list, + indices_of_seq_with_bonus_tokens) + model_outputs.append(model_output) + + # move indices to device to avoid stream sync + indices_of_seq_with_bonus_tokens = torch.tensor( + indices_of_seq_with_bonus_tokens, device=self.device) + filtered_model_outputs = self._filter_model_output( + model_outputs, indices_of_seq_with_bonus_tokens) + return filtered_model_outputs, True + + @staticmethod + def _maybe_update_previous_hidden_states( + model_output: SamplerOutput, + expanded_request: ExecuteModelRequest) -> None: + """ + Updates the previous hidden states in an expanded request + in-place with the hidden states from the model output. + """ + if expanded_request.previous_hidden_states is not None: + expanded_request.previous_hidden_states = HiddenStates( + model_output.hidden_states, + expanded_request.seq_group_metadata_list) + + @staticmethod + def _expand_execute_model_request( + execute_model_req: ExecuteModelRequest, + seq_with_bonus_token_in_last_step: set, + ) -> Tuple[ExecuteModelRequest, List[int]]: + """ + Expands the execute model request based on sequences with bonus + tokens. + + For each sequence with a bonus token, this method creates a new + sequence without the bonus token and adds it to the execute model + request. The original sequence groups are also retained. The indices + of the original sequence groups are returned for further processing. + + Args: + execute_model_req (ExecuteModelRequest): The original execute + model request. + seq_with_bonus_token_in_last_step (set): Set of sequence IDs that + contain bonus tokens. + + Returns: + Tuple[ExecuteModelRequest, List[int]]: The updated execute model + request with expanded sequences and a list of indices corresponding + to the original sequence groups. 
+ """ + updated_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + updated_execute_model_req = execute_model_req.clone( + updated_seq_group_metadata_list) + indices_of_original_sequence_groups = [] + for seq_group in execute_model_req.seq_group_metadata_list: + seq_group_has_bonus_tokens = False + for seq_id, _ in seq_group.seq_data.items(): + # Identify sequences with bonus tokens in the sequence group. + if seq_id in seq_with_bonus_token_in_last_step: + seq_group_has_bonus_tokens = True + break + if seq_group_has_bonus_tokens: + #Create new sequences without the last bonus token. These new + # sequence have the same sequence id as the original sequence. + # We create a new sequence group and add them there. + updated_seq_group_without_bonus_token = \ + MultiStepWorker._copy_seq_metadata_excluding_last_token( + seq_group, seq_with_bonus_token_in_last_step) + updated_seq_group_metadata_list.append( + updated_seq_group_without_bonus_token) + # Add the original sequence group. + updated_seq_group_metadata_list.append( + MultiStepWorker._shallow_copy_seq_group_metadata(seq_group)) + # Record the index of the original sequence group. + indices_of_original_sequence_groups.append( + len(updated_seq_group_metadata_list) - 1) + + updated_execute_model_req.seq_group_metadata_list =\ + updated_seq_group_metadata_list + + if isinstance(updated_execute_model_req.previous_hidden_states, + HiddenStates): + updated_execute_model_req.previous_hidden_states\ + .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) + + return updated_execute_model_req, indices_of_original_sequence_groups + + @staticmethod + def _filter_model_output( + expanded_batch_outputs: List[SamplerOutput], + output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]: + """ + Filters the model output to include only the specified sequence + outputs. This method contracts the expanded batch output from the + model to retain the outputs of only those sequences indicated by the + provided indices. + + Args: + expanded_batch_output (List[SamplerOutput]): The expanded output + batch from the model. + output_indices_to_retain (torch.Tensor): Indices of the model + outputs to retain. + + Returns: + List[SamplerOutput]: A list containing the filtered model + outputs for the specified indices. + """ + return [ + SamplerOutput( + outputs=[ + expanded_batch_output.outputs[i] + for i in output_indices_to_retain + ] if len(expanded_batch_output.outputs) > 0 else [], + sampled_token_probs=( + expanded_batch_output. + sampled_token_probs[output_indices_to_retain] + if expanded_batch_output.sampled_token_probs is not None + else None), + logprobs=( + expanded_batch_output.logprobs[output_indices_to_retain] + if expanded_batch_output.logprobs is not None else None), + sampled_token_ids=(expanded_batch_output. + sampled_token_ids[output_indices_to_retain] + if expanded_batch_output.sampled_token_ids + is not None else None)) + for expanded_batch_output in expanded_batch_outputs + ] + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: set, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. 
+ """ + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + @staticmethod + def _append_new_tokens( + model_output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + indices_of_seq_with_bonus_tokens: List[int]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done outside of the worker, but it is + required if the worker is to perform multiple forward passes. + """ + count = 0 + for index, (seq_group_metadata, sequence_group_outputs) in enumerate( + zip(seq_group_metadata_list, model_output)): + seq_group_metadata.is_prompt = False + + for seq_output in sequence_group_outputs.samples: + # NOTE: Beam search is not supported, so we can assume that + # parent_seq_id == seq_id. + seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + + token_id = seq_output.output_token + token_logprob = seq_output.logprobs[token_id] + # Determine the actual token ID to be generated, + # considering bonus tokens + if index != indices_of_seq_with_bonus_tokens[count]: + bonus_seq_metadata = seq_group_metadata_list[ + indices_of_seq_with_bonus_tokens[count]] + _, bonus_token_seq_data = next( + iter(bonus_seq_metadata.seq_data.items())) + token_id = bonus_token_seq_data.output_token_ids[-1] + else: + count += 1 + + seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) + + @staticmethod + def _shallow_copy_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata: + """Copy input data structures to remove side-effects when input data + structures are shared with other modules. + + Helpful when the vLLM scheduler runs in the same process as the worker. + The alternative is deep-copying (or other form of deep copy); this has + performance downsides. + """ + # Shallow-copy the SequenceGroupMetadata. This allows us to + # append tokens and change is_prompt without external side-effects. + # We must shallow-copy seq_group_metadata as is_prompt could change. + new_seq_group_metadata = copy.copy(seq_group_metadata) + + # We must shallow-copy seq_data as we will append token ids + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + new_seq_data[seq_id] = copy.copy(old_seq_data) + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:] + + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata + + @staticmethod + def _copy_seq_metadata_excluding_last_token( + seq_group_metadata: SequenceGroupMetadata, + seq_ids_to_copy: Set[int], + ) -> SequenceGroupMetadata: + """ + Creates a shallow copy of the given SequenceGroupMetadata, retaining + only the sequence IDs specified in seq_ids_to_copy. For each of these + sequence IDs, all output_token_ids except the last one are copied. + Sequence IDs not in seq_ids_to_copy are excluded from the copy. + + Parameters: + seq_group_metadata (SequenceGroupMetadata): The original sequence + group metadata. + seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the + copy. + + Returns: + SequenceGroupMetadata: A shallow copy of the sequence group metadata + with the specified modifications. + """ + # Shallow-copy the SequenceGroupMetadata. + new_seq_group_metadata = copy.copy(seq_group_metadata) + # Shallow-copy seq_data and modify the output_token_ids. 
+ new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + if (seq_id in seq_ids_to_copy): + new_seq_data[seq_id] = copy.copy(old_seq_data) + # Copy all the output token ids except the last. + # Also reduce num_computed_tokens by 1 since we are not + # including the last output token. + # NOTE: num_computed_tokens is not directly used by the + # speculative decoding workers, as it is only relevant for + # chunked prefill, which is disabled for speculative decoding. + # However, to maintain consistency in num_computed_tokens, + # we update it here. + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:-1] + new_seq_data[seq_id].update_num_computed_tokens(-1) + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata + + def _assert_enough_kv_space( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + num_steps: int) -> None: + """Assert there are enough physical blocks per sequence to store the + current KV plus additional KV from num_steps tokens. + """ + assert self.model_runner.block_size is not None + for seq_group_metadata in seq_group_metadata_list: + # Only one seq_id is guaranteed because there is no beam search. + seq_id = list(seq_group_metadata.seq_data.keys())[0] + seq = seq_group_metadata.seq_data[seq_id] + + # After num_steps, the seq len will be the current seq len + # plus one token per step. + final_seq_len = seq.get_len() + num_steps + + # We will have final_seq_len - 1 KV because vLLM saves KV for a + # token in the iteration after the token was generated. + required_num_kv_slots = final_seq_len - 1 + + # The allocated number of kv slots is the number of allocated blocks + # times the number of slots of block. + number_physical_blocks = len( + seq_group_metadata.block_tables[seq_id]) + allocated_kv_slots = (number_physical_blocks * + self.model_runner.block_size) + + if required_num_kv_slots > allocated_kv_slots: + request_id = seq_group_metadata.request_id + raise ValueError( + "The worker attempted to run " + f"{num_steps} times but found insufficient KV space for " + f"{request_id=} {seq_id=}. ({allocated_kv_slots=} " + f"{required_num_kv_slots=}).") + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """MultiStepWorker does not yet implement support for cache swap + operations or beam search. 
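A worked example, with made-up numbers, of the KV-space check in `_assert_enough_kv_space` above; the patch continues below with `_raise_if_unsupported`.

```python
# Assumed allocation for one sequence.
block_size = 16
num_physical_blocks = 4
current_seq_len = 60
num_steps = 5

final_seq_len = current_seq_len + num_steps               # 65
required_num_kv_slots = final_seq_len - 1                 # 64: KV is written one step late
allocated_kv_slots = num_physical_blocks * block_size     # 64
assert required_num_kv_slots <= allocated_kv_slots        # just enough space
print(required_num_kv_slots, allocated_kv_slots)
```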
+ """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MultiStepWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MultiStepWorker does not support beam search.") + + def maybe_load_lm_head_weight( + self, + lm_head_weight: torch.Tensor, + ) -> None: + weight_loader = getattr( + self.worker.model_runner.model_runner.model.lm_head.weight, + "weight_loader", default_weight_loader) + weight_loader( + self.worker.model_runner.model_runner.model.lm_head.weight, + lm_head_weight) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py new file mode 100644 index 0000000..57ae173 --- /dev/null +++ b/vllm/spec_decode/ngram_worker.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 + +import weakref +from typing import List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer + + +class _DummyModel(nn.Module): + pass + + +class NGramWorker(NonLLMProposerWorkerBase): + """NGramWorker provides a light drafter without need for model. + + Current NGramWorker only implements prompt lookup decoding, + and in future we may also do RAG type drafter and other scenarios + which don't rely on LLM model to give proposals. + """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + device_type: str = "cuda", + **kwargs, + ): + super().__init__(vllm_config) + + # Get local_rank/vocab_size from kwargs attribute + self.local_rank = local_rank + self.device_type = device_type + + # Lazy initialization list. + self._proposer: Top1Proposer + + def set_ngram_window_size(self, ngram_prompt_lookup_min: int, + ngram_prompt_lookup_max: int): + # Search valid candidate window between + # ngram_prompt_lookup_min/ngram_prompt_lookup_max + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + + def init_device(self): + self.device = torch.device(f"{self.device_type}:{self.local_rank}") + + # Current NGramWorker only supports Top1Proposer + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + device=self.device, + vocab_size=self.vocab_size, + ) + + def load_model(self) -> None: + pass # Dummy + + def get_model(self) -> nn.Module: + return _DummyModel() + + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]: + """NGram match algo to pick proposal candidate. Returns the list of + sampler output, one per SequenceGroupMetadata. + + For ngram worker, we already done needed transposed internal, so the + indicator pass to sampler_output_to_torch shall be False. 
+ """ + self._raise_if_unsupported(execute_model_req) + + has_spec_out = False + token_id_list: List[Optional[torch.Tensor]] = [] + token_prob_list: List[Optional[torch.Tensor]] = [] + for idx, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + seq_data = next(iter(seq_group_metadata.seq_data.values())) + + seq_len = seq_data.get_len() + # When seq_len is less than 3072 (3K), we use CPU to perform + # the ngram match. Otherwise, we use the device specified in + # the model config (normally GPU). 3072 is a rough threshold + # based on profiling on H100, and it can be adjusted based + # on the actual performance on different hardware. + cur_device = "cpu" if seq_len < 3072 else self.device + input_ids = torch.as_tensor(seq_data.get_token_ids(), + dtype=torch.long, + device=cur_device) + input_length = seq_data.get_len() + + for ngram_size in range( + min(self.ngram_prompt_lookup_max, input_length - 1), + self.ngram_prompt_lookup_min - 1, + -1, + ): + ngram_tensor = input_ids[-ngram_size:] + if ngram_size == 1: + # Do not match itself and do not use unfold and all + matches = (input_ids[:-1] == ngram_tensor) + else: + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + # Do not match itself + matches = (windows[:-1] == ngram_tensor).all(dim=-1) + + # first_match includes "values" (bool), indicating whether + # the match is found, and "indices", indicating the index + # of the first match. + first_match = matches.max(dim=-1) + if first_match.values.item(): + proposal_start_idx = first_match.indices.add_(ngram_size) + spec_indices = ( + proposal_start_idx).repeat(sample_len) + torch.arange( + sample_len, device=cur_device) + spec_indices.clamp_(max=input_ids.shape[-1] - 1) + res = input_ids.gather(dim=-1, + index=spec_indices).to(self.device) + token_id_list.append(res) + token_prob_list.append( + torch.nn.functional.one_hot( + res, + num_classes=self.vocab_size).to(torch.float32)) + has_spec_out = True + break + else: + token_id_list.append(None) + token_prob_list.append(None) + + if not has_spec_out: + return None, False + + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(execute_model_req.seq_group_metadata_list)): + if token_id_list[idx] is None: + outputs.append(None) + else: + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx], + logprobs=torch.zeros((sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device), + sampled_token_ids=token_id_list[idx], + )) + + return outputs, False + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """NGramWorker does not yet implement support for cache swap + operations or beam search. 
+ """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "NGramWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py new file mode 100644 index 0000000..2829d63 --- /dev/null +++ b/vllm/spec_decode/proposer_worker_base.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import List, Optional, Set, Tuple + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposer +from vllm.worker.worker_base import LoRANotSupportedWorkerBase + + +class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer): + """Interface for proposer workers""" + + @abstractmethod + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # A set containing all sequence IDs that were assigned bonus tokens + # in their last forward pass. This set is used to backfill the KV cache + # with the key-value pairs of the penultimate token in the sequences. + # This parameter is only used by the MultiStepWorker, which relies on + # the KV cache for token generation. It is not used by workers that + # do not utilize the KV cache. + seq_ids_with_bonus_token_in_last_step: Set[int] + ) -> Tuple[Optional[List[SamplerOutput]], bool]: + raise NotImplementedError + + def set_include_gpu_probs_tensor(self) -> None: + """Implementation optional""" + pass + + def set_should_modify_greedy_probs_inplace(self) -> None: + """Implementation optional""" + pass + + +class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC): + """Proposer worker which does not use a model with kvcache""" + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """get_spec_proposals is used to get the proposals""" + return [] + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """This is never called on the proposer, only the target model""" + raise NotImplementedError + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + pass + + def get_cache_block_size_bytes(self) -> int: + return 0 diff --git a/vllm/spec_decode/smaller_tp_pp_proposer_worker.py b/vllm/spec_decode/smaller_tp_pp_proposer_worker.py new file mode 100644 index 0000000..65ad239 --- /dev/null +++ b/vllm/spec_decode/smaller_tp_pp_proposer_worker.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from vllm.distributed.parallel_state import (get_pp_group, get_tp_group, + init_model_parallel_group, + patch_model_parallel_group) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase + +logger = init_logger(__name__) + + +class 
_DummyModel(nn.Module): + pass + + +class SmallerTpPpProposerWorker(ProposerWorkerBase): + """Class which allows a speculative draft model to run with smaller tensor + parallel degree and pipeline parallel degree than target model. + This reduces the communication overhead of small draft models. + + To implement this feature, this class differs behavior based on is_dummy + flag, where dummy means worker that does not participate draft generation. + Participating workers use a smaller tp group by patching vLLM's tensor + parallel group temporarily during forward passes of draft models. + """ + + @classmethod + def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int, + target_tensor_parallel_size: int, + draft_pipeline_parallel_size: int, + target_pipeline_parallel_size: int): + """Wrap the worker in a SmallerTpProposerWorker if necessary. + """ + if draft_tensor_parallel_size == target_tensor_parallel_size and \ + draft_pipeline_parallel_size == target_pipeline_parallel_size: + return worker + + # gpu ranks that will generate draft tokens together + draft_tp_ranks = list(range(draft_tensor_parallel_size)) + draft_pp_ranks = list(range(draft_pipeline_parallel_size)) + + logger.info("Wrapping {%s} in {%s}", type(worker), cls) + return cls(worker, draft_tp_ranks, draft_pp_ranks) + + def __init__(self, worker: MultiStepWorker, draft_tp_ranks: List[int], + draft_pp_ranks: List[int]): + """Create a SmallerTpPpProposerWorker. + + Args: + worker (MultiStepWorker): an actual worker wrapped with this class + draft_tp_ranks (List[int]): if this value is given, only the GPU tp + ranks written in this value participate in draft generation + draft_pp_ranks (List[int]): if this value is given, only the GPU pp + ranks written in this value participate in draft generation + """ + self._worker = worker + self._draft_tp_ranks = draft_tp_ranks + self._draft_pp_ranks = draft_pp_ranks + + # init during init_device + self._is_dummy = False + self._tp_group = None + self._pp_group = None + + def _patch_model_parallel_group(self): + """Temporarily patch the global pp group state with its own pp group + state. 
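+        The tp group is patched alongside it; both patches are applied via
+        the returned context manager (see the
+        `with self._patch_model_parallel_group()` call sites below).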
+ """ + return patch_model_parallel_group(self._tp_group, self._pp_group) + + def init_device(self) -> None: + # The non-driver workers need to be dummy + + self._is_dummy = get_tp_group().rank not in self._draft_tp_ranks \ + and get_pp_group().rank not in self._draft_pp_ranks + + + + # dummy workers do nothing + if self._is_dummy: + return + + # creates tp and pp process group containing only a subset of gpu ranks + local_tp_rank = get_tp_group().local_rank + local_pp_rank = get_pp_group().local_rank + tp_backend = torch.distributed.get_backend(get_tp_group().device_group) + pp_backend = torch.distributed.get_backend(get_pp_group().device_group) + + all_ranks = torch.arange(len(self._draft_tp_ranks) * len(self._draft_pp_ranks)).reshape( len(self._draft_pp_ranks), len(self._draft_tp_ranks)) + group_ranks = all_ranks.view(-1, len(self._draft_tp_ranks)).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + self._tp_group = init_model_parallel_group(group_ranks, + local_tp_rank, tp_backend) + + group_ranks = all_ranks.transpose(0, 1).reshape( + -1, len(self._draft_pp_ranks)).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + self._pp_group = init_model_parallel_group(group_ranks, + local_pp_rank, pp_backend) + + with self._patch_model_parallel_group(): + self._worker.init_device() + + def set_include_gpu_probs_tensor(self) -> None: + if self._is_dummy: + return + + # Need include_gpu_probs_tensor for multi_step_worker + self._worker.set_include_gpu_probs_tensor() + + def set_should_modify_greedy_probs_inplace(self) -> None: + if self._is_dummy: + return + + self._worker.set_should_modify_greedy_probs_inplace() + + def load_model(self) -> None: + if self._is_dummy: + return + + with self._patch_model_parallel_group(): + self._worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + if self._is_dummy: + # this case is not used now + return -1, -1 + + with self._patch_model_parallel_group(): + return self._worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + if self._is_dummy: + return + + with self._patch_model_parallel_group(): + self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + # Do not check _is_dummy, as it's always called by get_spec_proposals + return self._worker.sampler_output( + execute_model_req, sample_len, + seq_ids_with_bonus_token_in_last_step) + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. 
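+        Dummy ranks (those outside the draft tp/pp groups) return an empty
+        SpeculativeProposals so that every rank follows the same call
+        sequence.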
+ """ + if self._is_dummy: + return SpeculativeProposals(None, None, None) + + with self._patch_model_parallel_group(): + return self._worker.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def get_model(self) -> nn.Module: + if self._is_dummy: + return _DummyModel() + + with self._patch_model_parallel_group(): + return self._worker.get_model() + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + if self._is_dummy: + return [] + + with self._patch_model_parallel_group(): + return self._worker.execute_model(execute_model_req) + + def get_cache_block_size_bytes(self) -> int: + if self._is_dummy: + # by returning zero, target worker can use the entire kv cache space + return 0 + + return self._worker.get_cache_block_size_bytes() + + @property + def vocab_size(self) -> int: + return self._worker.vocab_size + + def maybe_load_lm_head_weight( + self, + lm_head_weight: torch.Tensor, + ) -> None: + if self._is_dummy: + return + + with self._patch_model_parallel_group(): + weight_loader = getattr( + self._worker.worker.model_runner.model_runner.model.\ + lm_head.weight, + "weight_loader", + default_weight_loader) + weight_loader( + self._worker.worker.model_runner.model_runner.model.\ + lm_head.weight, + lm_head_weight) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py new file mode 100644 index 0000000..a794ad7 --- /dev/null +++ b/vllm/spec_decode/spec_decode_worker.py @@ -0,0 +1,1407 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +from collections import defaultdict +from functools import cached_property +from typing import Any, Dict, List, Optional, Set, Tuple, Type + +import torch +import torch.nn as nn + +from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.distributed.communication_op import (broadcast_tensor_dict, + get_tp_group, + tensor_model_parallel_gather) +from vllm.distributed.parallel_state import model_parallel_is_initialized +from vllm.logger import init_logger +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) +from vllm.platforms import current_platform +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, ExecuteModelRequest, + HiddenStates, IntermediateTensors, + SequenceGroupMetadata, + get_all_seq_ids_and_request_ids) +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer + +if current_platform.is_cuda_alike(): + from vllm.spec_decode.draft_model_runner import TP1PP1DraftModelRunner + +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.medusa_worker import MedusaWorker +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker +from vllm.spec_decode.mqa_scorer import MQAScorer +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.smaller_tp_pp_proposer_worker import ( + SmallerTpPpProposerWorker) +from 
vllm.spec_decode.target_model_runner import TargetModelRunner +from vllm.spec_decode.util import (Timer, create_logprobs_output, + create_sequence_group_output, + get_all_num_logprobs, + get_sampled_token_logprobs, nvtx_range, + split_batch_by_proposal_len) +from vllm.utils import resolve_obj_by_qualname +from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase + +logger = init_logger(__name__) + + +def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": + """Helper method that is the entrypoint for Executors which use + WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. + """ + vllm_config: VllmConfig = kwargs.get("vllm_config") + speculative_config: SpeculativeConfig = vllm_config.speculative_config + assert speculative_config is not None + + draft_worker_kwargs = kwargs.copy() + + kwargs["model_runner_cls"] = TargetModelRunner + target_worker_config = copy.deepcopy(vllm_config) + target_worker_config.parallel_config.worker_cls =\ + target_worker_config.parallel_config.sd_worker_cls + cls = resolve_obj_by_qualname( + target_worker_config.parallel_config.worker_cls) + target_worker = cls(*args, **kwargs) + # Set the disable_logprobs variable in the TargetModelRunner instance + # as per its value specified in the SpeculativeConfig. + target_worker.model_runner.disable_logprobs =\ + speculative_config.disable_logprobs + + draft_worker_config = copy.deepcopy(vllm_config) + draft_worker_config.model_config = speculative_config.draft_model_config + draft_worker_config.quant_config = VllmConfig._get_quantization_config( + draft_worker_config.model_config, + vllm_config.load_config, + ) + speculative_config.draft_parallel_config.worker_cls =\ + draft_worker_config.parallel_config.sd_worker_cls + draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa + # TODO allow draft-model specific load config. + + # Override draft-model specific worker args. + draft_worker_kwargs.update( + vllm_config=draft_worker_config, + ngram_prompt_lookup_max=speculative_config.prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.prompt_lookup_min, + ) + + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + disable_mqa_scorer=speculative_config.disable_mqa_scorer, + disable_by_batch_size=speculative_config.disable_by_batch_size, + draft_token_acceptance_method=speculative_config.acceptance_method, + typical_acceptance_sampler_posterior_threshold=speculative_config. + posterior_threshold, + typical_acceptance_sampler_posterior_alpha=speculative_config. + posterior_alpha, + disable_logprobs=speculative_config.disable_logprobs, + disable_log_stats=speculative_config.disable_log_stats, + num_speculative_tokens=speculative_config.num_speculative_tokens, + ) + + return spec_decode_worker + + +# Reminder: Please update docs/source/features/compatibility_matrix.md +# If the feature combo become valid +class SpecDecodeWorker(LoRANotSupportedWorkerBase): + """Worker which implements speculative decoding. + + Speculative decoding reduces decoding per-token latency by using a proposal + method, such as a small draft model, to speculate ahead of a larger LLM. The + probabilities of the speculative tokens are then determined by the larger + LLM, after which some verification routine determines which (if any) of the + speculative tokens are accepted by the larger LLM. 
+ + See https://github.com/vllm-project/vllm/pull/2188 and + https://github.com/vllm-project/vllm/pull/3103 for more info. + + The current implementation has the following limitations: + * Only draft-model proposal is implemented (contributions for more forms are + welcome!). + * Only top-1 proposal and scoring are implemented. Tree-attention is left as + future work. + * All sequences in a batch must have the same proposal length, or zero. This + can be improved by having per-sequence speculation in the future. + * The scoring forward pass is done without an MQA kernel, which is + suboptimal especially as the batch size, proposal length, and sequence + lengths grow. Contributions to add a MQA scoring are welcome once + correctness tests pass. + More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. + """ + + @classmethod + def create_worker( + cls, + scorer_worker: WorkerBase, + draft_worker_kwargs: Dict[str, Any], + disable_mqa_scorer: bool, + disable_by_batch_size: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, + disable_logprobs: bool, + disable_log_stats: bool, + num_speculative_tokens: int, + ) -> "SpecDecodeWorker": + + allow_zero_draft_token_step = True + enable_lm_head_weight_load = False + num_spec_prefill_steps = 1 + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + draft_model_config = draft_worker_kwargs["vllm_config"].model_config + draft_parallel_config: ParallelConfig = draft_worker_kwargs[ + 'vllm_config'].parallel_config + if ngram_prompt_lookup_max > 0: + draft_worker_kwargs[ + "device_type"] = scorer_worker.device_config.device.type + proposer_worker = NGramWorker(**draft_worker_kwargs) + proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, + ngram_prompt_lookup_max) + else: + draft_tp = draft_parallel_config.tensor_parallel_size + target_tp = scorer_worker.parallel_config.tensor_parallel_size + draft_pp = draft_parallel_config.pipeline_parallel_size + target_pp = scorer_worker.parallel_config.pipeline_parallel_size + + if draft_model_config.hf_config.model_type == "mlp_speculator": + proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) + elif draft_model_config.hf_config.model_type == "medusa": + proposer_worker = MedusaWorker(**draft_worker_kwargs) + else: + if draft_tp == 1 and draft_pp == 1: + if current_platform.is_cuda_alike(): + draft_worker_kwargs[ + "model_runner_cls"] = TP1PP1DraftModelRunner + else: + if draft_model_config.hf_config.model_type == "eagle": + raise NotImplementedError( + f"{draft_model_config.hf_config.model_type} " + "does not support TP > 1 or PP > 1 yet") + + allow_zero_draft_token_step = False + + # Load lm_head weight for eagle in init_device + if draft_model_config.hf_config.model_type == "eagle": + enable_lm_head_weight_load = True + + proposer_worker = MultiStepWorker(**draft_worker_kwargs) + if draft_model_config.hf_config.model_type == "deepseek_mtp": + num_spec_prefill_steps = \ + draft_model_config.hf_config.n_predict + + proposer_worker = SmallerTpPpProposerWorker.maybe_wrap_worker( + proposer_worker, draft_tp, target_tp, draft_pp, target_pp) + + logger.info("Configuring SpecDecodeWorker with proposer=%s", + type(proposer_worker)) + + spec_decode_sampler: SpecDecodeBaseSampler = None + if draft_token_acceptance_method == 
"rejection_sampler": + spec_decode_sampler = RejectionSampler() + elif draft_token_acceptance_method == "typical_acceptance_sampler": + spec_decode_sampler = TypicalAcceptanceSampler( + posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + posterior_alpha=typical_acceptance_sampler_posterior_alpha, + ) + logger.info( + "[Speculative Decoding] Configuring" + " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) + + if not disable_mqa_scorer: + if scorer_worker.model_runner.attn_backend.get_name( + ) != "FLASH_ATTN": + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "MQA is only available with flash attn backend.") + + if draft_model_config and \ + draft_model_config.max_model_len < \ + scorer_worker.model_config.max_model_len: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "draft model max_model_len is smaller than the target " + "model max_model_len.") + + if not scorer_worker.model_runner.model_config.enforce_eager: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "target model is not running in eager mode.") + + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + disable_mqa_scorer=disable_mqa_scorer, + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, + disable_by_batch_size=disable_by_batch_size, + spec_decode_sampler=spec_decode_sampler, + allow_zero_draft_token_step=allow_zero_draft_token_step, + enable_lm_head_weight_load=enable_lm_head_weight_load, + num_spec_prefill_steps=num_spec_prefill_steps) + + def __init__( + self, + proposer_worker: ProposerWorkerBase, + scorer_worker: WorkerBase, + spec_decode_sampler: SpecDecodeBaseSampler, + disable_mqa_scorer: bool = False, + disable_logprobs: bool = False, + disable_log_stats: bool = False, + metrics_collector: Optional[AsyncMetricsCollector] = None, + disable_by_batch_size: Optional[int] = None, + allow_zero_draft_token_step: Optional[bool] = True, + enable_lm_head_weight_load: Optional[bool] = False, + num_spec_prefill_steps: int = 1, + ): + """ + Create a SpecDecodeWorker. + + Args: + proposer_worker: A worker that can produce speculative tokens for + sequences. + scorer_worker: A worker that produces probabilities of speculative + tokens according to some base model. Typically a vanilla vLLM + Worker. + spec_decode_sampler: A Torch module used to perform acceptance + sampling of the draft tokens in the verification step of + speculative decoding. Currently we support two different + types of sampler namely RejectionSampler and + TypicalAcceptanceSampler. 'spec_decode_sampler' is either an + instance of RejectionSampler or TypicalAcceptanceSampler. + disable_mqa_scorer: If set to True, disable the MQA scorer and use + the BatchExpansionTop1Scorer instead. + disable_logprobs: If set to True, token log probabilities will + not be output in both the draft worker and the target worker. + If set to False, log probabilities will be output by both. + disable_log_stats: If set to True, disable periodic printing of + speculative stage times. + disable_by_batch_size: If the batch size is larger than this, + disable speculative decoding for new incoming requests. + metrics_collector: Helper class for collecting metrics; can be set + for testing purposes. 
+ allow_zero_draft_token_step: whether to allow a step where the draft + model generates no draft token; should disallow when the tp of + draft model is larger than 1 (TODO: #5814) + enable_lm_head_weight_load: whether to load lm_head weight for + draft models like eagle. + num_spec_prefill_steps: number of speculative prefill steps to run + before the speculative decoding starts. This is only used when + the draft model is a deepseek_mtp model that requires prefill + kv cache separately for each MTP layer. + """ + self.proposer_worker = proposer_worker + self.scorer_worker = scorer_worker + scorer_runner = getattr(self.scorer_worker, "model_runner", None) + self.generators = scorer_runner.get_generators( + ) if scorer_runner else None + self.disable_by_batch_size = disable_by_batch_size or float("inf") + self.spec_decode_sampler = spec_decode_sampler + self._allow_zero_draft_token_step = allow_zero_draft_token_step + self._enable_lm_head_weight_load = enable_lm_head_weight_load + self._metrics = AsyncMetricsCollector( + self.spec_decode_sampler + ) if metrics_collector is None else metrics_collector + # Tracks the sequence IDs that received a bonus token ID in + # their last forward pass. Needed only if KV cache is being + # used for token generation such as in the case of MultiStepWorker. + self._seq_with_bonus_token_in_last_step: Set[int] = set() + # Tracks the currently active request ids and the sequence IDs + # corresponding to them + self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set) + # Tracks if the proposer worker uses the KV cache or not. + + self.probs_dtype = self.spec_decode_sampler.probs_dtype + self.token_id_dtype = self.spec_decode_sampler.token_id_dtype + # Lazy initialization. + self.scorer: SpeculativeScorer + self.disable_mqa_scorer = disable_mqa_scorer + + # Hidden states from target model to pass to proposer + # in the subsequent step. + self.previous_hidden_states: Dict[int, Optional[HiddenStates]] = {} + self._disable_logprobs = disable_logprobs + self._disable_log_stats = disable_log_stats + self._num_spec_prefill_steps = num_spec_prefill_steps + + def init_device(self) -> None: + """Initialize both scorer and proposer models. + """ + # The scorer worker model is initialized first in case the proposer + # model has a smaller TP degree than the target worker. + self.scorer_worker.init_device() + self.proposer_worker.init_device() + + # NOTE(cade): load_model is not part of the WorkerBase interface. 
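+        # Load the scorer (target) model and the proposer. When lm_head
+        # weight loading is enabled (eagle-style drafters), the target
+        # lm_head gathered below is copied into the proposer afterwards.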
+ self.scorer_worker.load_model() + self.proposer_worker.load_model() + + if self._enable_lm_head_weight_load: + if get_pp_group().is_last_rank: + # NOTE(Shangming): gather lm_head weight when tp enabled + target_lm_head_weight: torch.Tensor = \ + tensor_model_parallel_gather( + self.scorer_worker.model_runner.model_runner.model.lm_head.weight.data, + dim=0, + ) + # Send target_lm_head_weight to the first pp rank + tensors = {"target_lm_head_weight": target_lm_head_weight} + get_pp_group().send_tensor_dict(tensors, dst=self._driver_rank) + else: + # Receive target_lm_head_weight from the last pp rank + tensors = get_pp_group().recv_tensor_dict( + src=get_pp_group().world_size - 1) + target_lm_head_weight = tensors["target_lm_head_weight"] + + self.proposer_worker.maybe_load_lm_head_weight( + target_lm_head_weight) + + # self._metrics.init_tensors(self.rank, device_type=self.device) + if model_parallel_is_initialized(): + self._metrics.init_tensors(get_tp_group().rank_in_group, + device_type=self.device) + self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, + device_type=self.device) + else: + self._metrics.init_tensors(self.rank, device_type=self.device) + self.spec_decode_sampler.init_tensors(self.rank, + device_type=self.device) + + scorer_cls: Type[SpeculativeScorer] + if self.disable_mqa_scorer: + scorer_cls = BatchExpansionTop1Scorer + logger.info("[Speculative Decoding] Use batch " + "expansion for scoring proposals.") + else: + scorer_cls = MQAScorer + logger.info( + "[Speculative Decoding] Use MQA scorer for scoring proposals.") + + self.scorer = scorer_cls(scorer_worker=self.scorer_worker, + device=self.device, + vocab_size=self._vocab_size) + + self._configure_model_sampler_for_spec_decode() + + def load_model(self, *args, **kwargs): + pass + + def _configure_model_sampler_for_spec_decode(self): + """Configure model sampler to emit GPU tensors. This allows spec decode + to keep data on device without transferring to CPU and serializing, + which significantly reduces overhead of sampling during verification. + + NOTE(cade): This breaks abstraction boundaries pretty badly. The better + design is to have the "move to CPU and serialize" sampling decision be + done outside of the model/sampler; this way the "last-mile" worker + object which interfaces with the scheduler can serialize and incur the + performance hit as necessary. This allows us to run the worker several + iterations in a row without incurring the "move to CPU and serialize" + performance penalty. + + Since this requires a large change to vLLM, we defer it to later and + temporarily accept this broken abstraction boundary. + + NOTE(cade): This will require a special check if the proposer worker + does not have a sampler (e.g. ngram speculation). + """ + (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + ) = True + (self.scorer_worker.model_runner.model.sampler. + should_modify_greedy_probs_inplace) = True + self.proposer_worker.set_include_gpu_probs_tensor() + self.proposer_worker.set_should_modify_greedy_probs_inplace() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of cache blocks to use. + + This is done by profiling the scorer model (which is typically the + larger of the two). Then the total memory which would be used by the + scorer cache is divided evenly between the proposer and scorer model KV, + such that the number of blocks is equal in both KV caches. 
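+
+        Sketch of the split (assuming an even division by block size): if
+        profiling the scorer yields N blocks of size S bytes and the
+        proposer's block size is P bytes, the shared count becomes roughly
+        N * S / (S + P), so both KV caches hold the same number of blocks
+        within the same memory budget.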
+ """ + if hasattr(self.scorer_worker, 'model_runner') and hasattr(self.proposer_worker, 'model_runner'): + self.scorer_worker.model_runner.model_memory_usage = self.scorer_worker.model_runner.model_memory_usage + self.proposer_worker.model_runner.model_memory_usage + num_gpu_blocks, num_cpu_blocks = ( + self.scorer_worker.determine_num_available_blocks()) + + scorer_cache_block_size_bytes = ( + self.scorer_worker.get_cache_block_size_bytes()) + proposer_cache_block_size_bytes = ( + self.proposer_worker.get_cache_block_size_bytes()) + + new_num_gpu_blocks = split_num_cache_blocks_evenly( + scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, + num_gpu_blocks) + return new_num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the cache engine of the scorer and proposer workers. + """ + self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def get_model(self) -> nn.Module: + return self.scorer_worker.get_model() + + @torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Perform speculative decoding on the input batch. + """ + rank = get_tp_group().rank_in_group if model_parallel_is_initialized( + ) else self.rank + if rank != self._driver_rank: + self._run_non_driver_rank() + return [] + + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return [] + + self._track_finished_requests(execute_model_req) + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + num_lookahead_slots = execute_model_req.num_lookahead_slots + all_prompt = True + atleast_one_prompt = False + all_zero_spec_tokens = True + for sgm in execute_model_req.seq_group_metadata_list: + all_prompt = all_prompt and sgm.is_prompt + atleast_one_prompt = atleast_one_prompt or sgm.is_prompt + all_zero_spec_tokens = all_zero_spec_tokens and ( + sgm.num_speculative_tokens == 0) + + if all_prompt and execute_model_req.seq_group_metadata_list: + assert num_lookahead_slots == 0, ( + "Prompt only runs should have num_lookahead_slots equal to 0. " + "This should never happen, please file a bug at " + "https://github.com/vllm-project/vllm/issues") + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch, or + # none of the requests in the batch have spec decoding enabled. + # In any of these cases, the proposer and scorer workers + # are called normally. + # We expect `num_speculative_tokens` to be None for prefills. + no_spec = (num_lookahead_slots == 0 or disable_all_speculation + or all_zero_spec_tokens) + + # Broadcast how many lookahead slots are scheduled for this step, and + # whether all speculation is disabled, to all non-driver workers. 
+ + # This is required as if the number of draft model runs changes + # dynamically, the non-driver workers won't know unless we perform a + # communication to inform them. + + # no_spec is used to signal non-driver worker about prefill vs decode + # stage. This is needed to ensure that order of execution of proposer + # and scorer is same in both driver and non-driver workers (i.e., + # scorer -> proposer for prefill and proposer -> scorer in decode). This + # order is needed to support models like EAGLE that take scorer states + # as inputs. + broadcast_dict = dict( + num_lookahead_slots=num_lookahead_slots, + no_spec=no_spec, + disable_all_speculation=disable_all_speculation, + # When both chunked prefill and speculative decoding are enabled + # it is possible that the same batch contains both prefill + # and decodes. If that happens in the scorer we run the batch + # as one single forward pass. However, in the proposer we + # run them as 2 different batches - one for prefill and + # the other for decodes. The variable indicates to the non-driver + # worker that there are prefills as part of the speculative batch + # and hence it needs to run an extra prefill forward pass. + run_spec_proposer_for_prefill=atleast_one_prompt, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list") + + self._maybe_disable_speculative_tokens( + disable_all_speculation, execute_model_req.seq_group_metadata_list) + + if no_spec: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop to perform speculative decoding + in parallel worker.""" + while self._run_non_driver_rank(): + pass + + def _should_disable_all_speculation( + self, execute_model_req: ExecuteModelRequest) -> bool: + # When the batch size is too large, disable speculative decoding + # to stop trading off throughput for latency. + return (execute_model_req.running_queue_size + >= self.disable_by_batch_size) + + def _maybe_disable_speculative_tokens( + self, disable_all_speculation: bool, + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + if not disable_all_speculation: + return + + for seq_group_metadata in seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 + + def _serialize_sampler_output_no_logprobs( + self, execute_model_req: ExecuteModelRequest, + sampler_output: SamplerOutput) -> List[SamplerOutput]: + """ + Creates and returns a `SamplerOutput` with only the token IDs being + serialized to CPU and populated in `CompletionSequenceGroupOutput`. + All other parameters in `CompletionSequenceGroupOutput` related to log + probabilities are skipped. + + Args: + execute_model_req (ExecuteModelRequest): The model request that + was executed. + sampler_output (SamplerOutput): The output from the sampler with + only GPU tensors populated. 
+ + Returns: + SamplerOutput: A new `SamplerOutput` instance containing a list of + `CompletionSequenceGroupOutput` objects with only token IDs + populated. + """ + seq_output_prompt_logprobs = [ + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 + for seq in execute_model_req.seq_group_metadata_list + ] + # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID + sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + # subtracting is faster than testing for equality + sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ + if any(seq_output_prompt_logprobs) else \ + sampler_output.sampled_token_ids).tolist() + + seq_data_entries = [ + (seq_id, seq_data) for sg in \ + execute_model_req.seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ] + completion_seq_group_output_list: List[ + CompletionSequenceGroupOutput] = [] + output_index = 0 + # Make sure the non-terminal prefill chunks are still aligned with + # their own empty output. + for idx, seq_group_meta in enumerate( + execute_model_req.seq_group_metadata_list): + needs_prompt_logprobs = seq_output_prompt_logprobs[idx] + seq_id, seq_data = seq_data_entries[idx] + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + + # Some of these sequences may belong to non-terminal chunks, + # which may still have to report logprobs for prompts. + start = 1 if seq_data._num_computed_tokens == 0 \ + else seq_data._num_computed_tokens + end = (seq_data._num_computed_tokens + \ + seq_group_meta.token_chunk_size) + prompt_token_ids = prompt_token_ids[start:end] + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) for p_token_id in prompt_token_ids + ] + else: + prompt_logprobs = None + + # Since we can get chunks here, we dont always have a sampled token + # (only on last chunk) but we still have to provide an output. + if not seq_group_meta.do_sample: + completion_seq_group_output_list.append( + CompletionSequenceGroupOutput( + samples=[], prompt_logprobs=prompt_logprobs)) + continue + + # Sequence with output. + completion_seq_group_output_list.append( + create_sequence_group_output( + token_id=sampled_token_ids_list[output_index][0], + token_id_logprob_rank=-1, + token_id_logprob=0.0, + seq_id=seq_id, + topk_token_ids=[], + topk_logprobs=[], + prompt_logprobs=prompt_logprobs)) + output_index += 1 + + return [SamplerOutput(outputs=completion_seq_group_output_list)] + + @nvtx_range("spec_decode_worker._run_no_spec") + def _run_no_spec(self, execute_model_req: ExecuteModelRequest, + skip_proposer: bool) -> List[SamplerOutput]: + """Run a single generation step without any speculation. The input is + sent to the proposer and scorer model so that the KV cache is consistent + between the two. When skip_proposer is True, the proposer model is + not called, meaning that the kv-cache in proposer for requests is not + updated, so they cannot enable spec decode in the rest decoding. + """ + + sampler_output = self.scorer_worker.execute_model(execute_model_req) + # assert len(sampler_output) == 1 + # sampler_output = sampler_output[0] + + # # Store hidden states from target model execution, BxD. 
+ # hidden_states = sampler_output.hidden_states + if get_pp_group().is_last_rank: + assert len(sampler_output) == 1 + sampler_output = sampler_output[0] + + # Store hidden states from target model execution, BxD. + sampled_token_ids = sampler_output.sampled_token_ids + hidden_states = sampler_output.hidden_states + prefill_hidden_states = sampler_output.prefill_hidden_states + tensors = { + "sampled_token_ids": sampled_token_ids, + "hidden_states": hidden_states, + "prefill_hidden_states": prefill_hidden_states + } + get_pp_group().broadcast_tensor_dict( + tensors, src=get_pp_group().world_size - 1) + else: + tensors = get_pp_group().broadcast_tensor_dict( + src=get_pp_group().world_size - 1) + sampled_token_ids = tensors["sampled_token_ids"] + hidden_states = tensors["hidden_states"] + prefill_hidden_states = tensors["prefill_hidden_states"] + sampler_output = SamplerOutput( + outputs=None, + sampled_token_ids=sampled_token_ids, + hidden_states=hidden_states, + prefill_hidden_states=prefill_hidden_states) + + # Only decodes and prefill terminal chunks need a hidden state. + seq_group_meta_with_hidden = [ + sg for sg in execute_model_req.seq_group_metadata_list + if sg.do_sample + ] + if hidden_states is not None: + # # Only decodes and prefill terminal chunks need a hidden state. + # seq_group_meta_with_hidden = [ + # sg for sg in execute_model_req.seq_group_metadata_list + # if sg.do_sample + # ] + if execute_model_req.virtual_engine not in \ + self.previous_hidden_states and \ + len(seq_group_meta_with_hidden): + self.previous_hidden_states[ + execute_model_req.virtual_engine] = HiddenStates( + hidden_states, seq_group_meta_with_hidden) + elif execute_model_req.virtual_engine in \ + self.previous_hidden_states and \ + len(seq_group_meta_with_hidden): + previous_hidden_states: HiddenStates = \ + self.previous_hidden_states[execute_model_req.virtual_engine] + previous_hidden_states.update(hidden_states, + seq_group_meta_with_hidden) + + if not skip_proposer: + # We prepare the prefill hidden states here so that there no + # additional complexity in worker for spec_decode vs non_spec_decode + # flow and execute_model doesn't need additional modifications. + execute_model_req.previous_hidden_states = \ + prepare_prefill_hidden_states( + sampler_output.prefill_hidden_states) + for i in range(self._num_spec_prefill_steps): + execute_model_req.spec_step_idx = i + self.proposer_worker.execute_model(execute_model_req) + + sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( + execute_model_req=execute_model_req, sampler_output=sampler_output) + if self._disable_logprobs else + [sampler_output]) + + # Clear device tensors from sampler output. This reduces communication + # overhead when the engine runs in a different process than the workers. + sampler_output.sampled_token_probs = None + sampler_output.sampled_token_ids = None + sampler_output.logprobs = None + return sampler_output_to_return + + def _run_non_driver_rank(self) -> bool: + """Run proposer and verifier model in non-driver workers. This is used + for both speculation cases (num_lookahead_slots>0) and non-speculation + cases (e.g. prefill). + + Returns True if there are remaining sequences to process. 
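+        Non-driver ranks learn what to execute from the dictionary broadcast
+        by the driver in execute_model (num_lookahead_slots, no_spec, ...);
+        an empty broadcast terminates the loop.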
+ """ + assert self.rank != self._driver_rank + + data = broadcast_tensor_dict(src=self._driver_rank) + if not data: + return False + num_lookahead_slots = data["num_lookahead_slots"] + + # In case of prefill, scorer_worker has to be run before proposer so + # that the hidden states can be propagated to proposer when needed. + if data["no_spec"]: + self.scorer_worker.execute_model() + + if not data["disable_all_speculation"]: + # Even if num_lookahead_slots is zero, we want to run the + # proposer model as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we + # should delegate how many times it runs to the proposer. + for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model() + + if not data["no_spec"]: + self.scorer_worker.execute_model() + if data["run_spec_proposer_for_prefill"]: + self.proposer_worker.execute_model() + + return True + + @nvtx_range("spec_decode_worker._run_speculative_decoding_step") + def _run_speculative_decoding_step( + self, execute_model_req: ExecuteModelRequest, + num_lookahead_slots: int) -> List[SamplerOutput]: + """Execute a single step of speculative decoding. + + This invokes the proposer worker to get k speculative tokens for each + sequence, then scores each speculative token using the scoring worker. + + When `enable_chunked_prefill` is set, scorer will batch decodes and + prefills, while proposer will sync its KV-cache by running an extra + forward on prefills. + + Returns a list of SamplerOutput, each containing a single token per + sequence. + """ + if get_pp_group().is_first_rank: + # With prefill chunking, expect requests to have prompts first + # so that backend gets prefill|decode. + assert num_lookahead_slots == execute_model_req.num_lookahead_slots + + # Pass last hidden states from target model to proposer + execute_model_req.previous_hidden_states = \ + self.previous_hidden_states[execute_model_req.virtual_engine] + self.previous_hidden_states.pop(execute_model_req.virtual_engine) + + with Timer() as proposal_timer: + # Generate proposals using draft worker. 
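+                # The set of sequences that received a bonus token in the
+                # previous step is passed along so the drafter can backfill
+                # its KV cache for those positions before proposing.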
+ proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) + + if not self._allow_zero_draft_token_step and proposals.no_proposals: + #TODO: Fix it #5814 + raise RuntimeError( + "Cannot handle cases where distributed draft " + "workers generate no tokens") + + proposal_execute_time = \ + proposal_timer.elapsed_time_ms / num_lookahead_slots + intermediate_tensors = IntermediateTensors({ + "proposal_token_ids": + proposals.proposal_token_ids, + "proposal_probs": + proposals.proposal_probs, + "proposal_lens": + proposals.proposal_lens, + "proposal_execute_time": + proposal_execute_time, + }) + get_pp_group().broadcast_tensor_dict(intermediate_tensors.tensors, + src=self._driver_rank) + else: + intermediate_tensors = IntermediateTensors( + get_pp_group().broadcast_tensor_dict(src=self._driver_rank)) + proposal_token_ids = intermediate_tensors["proposal_token_ids"] + proposal_probs = intermediate_tensors["proposal_probs"] + proposal_lens = intermediate_tensors["proposal_lens"] + proposal_execute_time = intermediate_tensors[ + "proposal_execute_time"] + proposals = SpeculativeProposals( + proposal_token_ids=proposal_token_ids, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens) + + execute_model_req.previous_hidden_states = None + + with Timer() as scoring_timer: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + ) + + _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len( + execute_model_req.seq_group_metadata_list, proposals.proposal_lens) + # With prefill chunking enabled, `non_spec_seqs` contains prefills too: + # discard decodes that have already been processed by proposer. + non_spec_indices = [ + idx for idx in non_spec_indices + if execute_model_req.seq_group_metadata_list[idx].is_prompt + ] + if len(non_spec_indices): + all_hidden_states = proposal_scores.hidden_states + if all_hidden_states is not None: + prefill_hidden_states = all_hidden_states[non_spec_indices] + execute_model_req.previous_hidden_states = \ + prepare_prefill_hidden_states(prefill_hidden_states) + # Sync proposer KV cache for prefills. + prefill_req = execute_model_req.clone(non_spec_seqs) + # TODO avoid sampling here? + self.proposer_worker.execute_model(prefill_req) + + with Timer() as verification_timer: + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req, proposal_scores, proposals, + execute_model_req.num_lookahead_slots) + + stage_times = (proposal_execute_time, scoring_timer.elapsed_time_ms, + verification_timer.elapsed_time_ms) + + return self._create_output_sampler_list( + execute_model_req.seq_group_metadata_list, + accepted_token_ids, + target_logprobs=target_logprobs, + prompt_logprobs=proposal_scores.prompt_logprobs + if not self._disable_logprobs else None, + k=execute_model_req.num_lookahead_slots, + stage_times=stage_times) + + @nvtx_range("spec_decode_worker._verify_tokens") + def _verify_tokens( + self, + execute_model_req: ExecuteModelRequest, + proposal_scores: SpeculativeScores, + proposals: SpeculativeProposals, + max_proposal_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Determine which speculative tokens are accepted using the + probabilities of each token according to the proposer and scorer models. + + Returns a tuple of Tensors, one for the accepted token ids and one for + the logprobs according to the scoring model. 
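+
+        Rejected positions are padded with -1 so every row has k+1 entries.
+        For example (hypothetical ids, k=3), a row [11, 7, -1, -1] means one
+        accepted draft token followed by the token sampled at the first
+        rejected position; the remaining slots are padding.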
+ """ + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + proposal_lens_list = proposals.proposal_lens.tolist() + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len( + seq_group_metadata_list, proposal_lens_list) + original_indices = spec_indices + non_spec_indices + + # Get probabilities of target model, including bonus tokens. + proposal_verifier_probs = proposal_scores.probs[spec_indices] + + # Get non-speculative sampled tokens from target model. + non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] + + # Get bonus tokens from target model. + bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + + # Get probabilities according to proposal method. + proposal_probs = proposals.proposal_probs[spec_indices] + + # Get proposed tokens. + proposal_token_ids = proposals.proposal_token_ids[spec_indices] + + # Sampler arguments + sampler_extra_kwargs: Dict[str, Any] = {} + if self.generators and isinstance(self.spec_decode_sampler, + SpecDecodeStochasticBaseSampler): + sampler_extra_kwargs["seeded_seqs"] = { + idx: self.generators[sgm.request_id] + for idx, sgm in enumerate(seq_group_metadata_list) + if sgm.sampling_params.seed is not None + } + + accepted_token_ids = self.spec_decode_sampler( + target_with_bonus_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, + **sampler_extra_kwargs, + ) + # Append output tokens from non-speculative sequences to + # the accepted token ids tensor. + non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + + 1).clone() + non_spec_token_ids[:, 1:] = -1 + accepted_token_ids = torch.cat( + [accepted_token_ids, non_spec_token_ids]) + logprobs = proposal_scores.logprobs + # Rearrange so that results are in the order of the original seq group + # metadata. + accepted_token_ids[original_indices] = accepted_token_ids.clone() + + # B x K+1 x D + hidden_states = proposal_scores.hidden_states + if hidden_states is not None: + # Only get terminal hidden states for next step + terminal_metadata = [ + sg for sg in seq_group_metadata_list if sg.do_sample + ] + + # Contract hidden states based on accepted tokens + hs_size = hidden_states.shape[-1] + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 + accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b + # Drop non-terminal prefill chunks hidden states. 
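+            # Rows whose accepted_index is -1 had no token accepted this step
+            # (e.g. non-terminal prefill chunks); they carry no usable hidden
+            # state and are filtered out here.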
+ hidden_states = hidden_states[accepted_index != + VLLM_INVALID_TOKEN_ID] + accepted_index = accepted_index[accepted_index != + VLLM_INVALID_TOKEN_ID] + assert len(accepted_index) == hidden_states.shape[0] == len( + terminal_metadata) + index = accepted_index[:, None, None].expand(-1, 1, + hs_size) # b x 1 x d + second_last_token_hidden_states = hidden_states[:, -2] # b x d + hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d + # Store hidden states from target model for subsequent decode step + self.previous_hidden_states[ + execute_model_req.virtual_engine] = HiddenStates( + hidden_states, terminal_metadata, + second_last_token_hidden_states) + return accepted_token_ids, logprobs + + def _create_output_sampler_list( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] + target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] + prompt_logprobs: Optional[ + torch.Tensor], # shape: [nprompt_tokens, vocab_size] + k: int, + stage_times: Tuple[float, float, float], + ) -> List[SamplerOutput]: + """Given the accepted token ids, create a list of SamplerOutput. + + The output is padded with -1 tokens such that each sequence has + the same number of outputs. + """ + batch_size, num_steps = accepted_token_ids.shape + accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) + if self._disable_logprobs: + # We are skipping the logprobs. Hence don't serialize the + # logprobs related tensors from the GPU. Instead create + # empty/dummy lists. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, + topk_logprobs_by_step, topk_indices_by_step) =\ + self._create_dummy_logprob_lists( + batch_size, num_steps, + self.scorer_worker.model_config.max_logprobs) + else: + # Organize input tensors by step instead of by sequence. + target_logprobs_by_step = target_logprobs.transpose(0, 1) + # Serialize all tensors into Python lists. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, + topk_logprobs_by_step, topk_indices_by_step) =\ + self._create_logprob_lists_from_tensors( + target_logprobs_by_step, accepted_token_ids_by_step, + self.scorer_worker.model_config.max_logprobs) + + # Get the sequence ids and num_logprobs (sampling parameter) in the + # batch. + seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids( + seq_group_metadata_list) + + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) + + # Serialize tensor to CPU Python list. + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + # Construct the output on a per-step, per-sequence basis. + # Non-terminal prefill chunks will end up here as rows with just -1s + # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while + # terminal chunks will only have one generated token at time 0. + sampler_output_list: List[SamplerOutput] = [] + + # Prefills are not multi-step (return at most 1 token), in order to + # avoid padding or repetition to fit decodes, we separate them. + for i, sg in enumerate(seq_group_metadata_list): + if not sg.is_prompt: + # Requests are ordered as prefills|decodes=>no more prefills. + break + num_logprobs = num_logprobs_per_seq[i] + seq_kwargs = dict(token_id=-1, + token_id_logprob_rank=0, + token_id_logprob=-float('inf'), + topk_token_ids=[-1] * num_logprobs, + topk_logprobs=[-float('inf')] * num_logprobs, + seq_id=seq_ids[i]) + # Terminal chunk, has token. 
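+            # Non-terminal prefill chunks keep the placeholder values from
+            # seq_kwargs above (token_id=-1); only sampling chunks overwrite
+            # them with the accepted token and its logprobs.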
+ if sg.do_sample: + seq_kwargs.update( + dict( + token_id=accepted_token_ids[i][0].item(), + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + 0][i], + token_id_logprob=accepted_token_id_logprobs_by_step[0] + [i], + topk_token_ids=topk_indices_by_step[0][i] + [:num_logprobs], + # output only so step is 0 + topk_logprobs=topk_logprobs_by_step[0][i] + [:num_logprobs], + )) + needs_plogs = (sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + plogs = None + if prompt_logprobs is not None: + # Even non-terminal prompt chunks can have logprobs here. + plogs = prompt_logprobs[i] + elif needs_plogs: + # Prompt logprobs are requested but `_disable_logprobs` is set. + seq_data = next(iter(sg.seq_data.values())) + # Get only the tokens in this chunk! + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_token_ids = prompt_token_ids[ + seq_data. + _num_computed_tokens:seq_data._num_computed_tokens + + sg.token_chunk_size] + + is_first_chunk = seq_data._num_computed_tokens == 0 + # There's no prob generated for the first token in a sequence. + if is_first_chunk: + prompt_token_ids = prompt_token_ids[1:] + plogs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) for p_token_id in prompt_token_ids + ] + seq_kwargs.update(dict(prompt_logprobs=plogs)) + + sampler_output_list.append( + SamplerOutput( + outputs=[create_sequence_group_output( + **seq_kwargs)])) # type: ignore + + # Decodes, create one SamplerOutput per-step (at most K+1). + for step_index in range(num_steps): + if all(token_id == -1 for sg, token_id in zip( + seq_group_metadata_list, + accepted_token_ids_by_step[step_index]) + if not sg.is_prompt): + break + + step_output_token_ids: List[CompletionSequenceGroupOutput] = [] + for sequence_index in range(batch_size): + seq_meta = seq_group_metadata_list[sequence_index] + # Prompts already processed above. + if seq_meta.is_prompt: + continue + + # Each sequence may have a different num_logprobs; retrieve it. + num_logprobs = num_logprobs_per_seq[sequence_index] + step_output_token_ids.append( + create_sequence_group_output( + token_id=accepted_token_ids_by_step[step_index] + [sequence_index], + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + step_index][sequence_index], + token_id_logprob=accepted_token_id_logprobs_by_step[ + step_index][sequence_index], + seq_id=seq_ids[sequence_index], + topk_token_ids=topk_indices_by_step[step_index] + [sequence_index][:num_logprobs], + topk_logprobs=topk_logprobs_by_step[step_index] + [sequence_index][:num_logprobs], + step_index=step_index)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + + # Populate the data structures needed to keep track of sequences with + # bonus tokens. + self._track_sequences_with_bonus_tokens(seq_ids, + request_ids_seq_ids_mapping, + accepted_token_ids_by_step) + maybe_rejsample_metrics = ( + self._metrics.maybe_collect_rejsample_metrics(k)) + if maybe_rejsample_metrics is not None: + sampler_output_list[ + 0].spec_decode_worker_metrics = maybe_rejsample_metrics + + # Log time spent in each stage periodically. + # This is periodic because the rejection sampler emits metrics + # periodically. + self._maybe_log_stage_times(*stage_times) + # First `n_prefills` entries will contain prefills SamplerOutput when + # chunked prefill is enabled, the rest is decodes in multi-step format. 
+ return sampler_output_list + + def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, + scoring_time_ms: float, + verification_time_ms: float) -> None: + """Log the speculative stage times. If stat logging is disabled, do + nothing. + """ + if self._disable_log_stats: + return + + logger.info( + "SpecDecodeWorker stage times: " + "average_time_per_proposal_tok_ms=%.02f " + "scoring_time_ms=%.02f verification_time_ms=%.02f", + average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) + + def _create_dummy_logprob_lists( + self, + batch_size: int, + num_steps: int, + num_top_k: int, + ) -> Tuple[List[List[int]], List[List[float]], + List[List[List[Optional[float]]]], + List[List[List[Optional[int]]]]]: + """ + Creates and returns four dummy lists representing token probabilities + and their ranks. + + This method initializes and returns: + - The ranks of the accepted tokens, shaped (num_steps, batch_size) + - The log probabilities of the accepted tokens, + shaped (num_steps, batch_size) + - The log probabilities of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + - The token IDs of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + + Args: + batch_size (int): The size of the batch. + num_steps (int): The number of steps in the sequence. + num_top_k (int): The number of top-k token log probabilities to + return. + + Returns: + A tuple containing four dummy lists as described above. + """ + accepted_token_id_ranks_by_step = [[-1] * batch_size + for _ in range(num_steps)] + accepted_token_id_logprobs_by_step = [[0.0] * batch_size + for _ in range(num_steps)] + topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[ + [None] * num_top_k for _ in range(batch_size) + ] for _ in range(num_steps)] + topk_indices_by_step: List[List[List[Optional[int]]]] = [[ + [None] * num_top_k for _ in range(batch_size) + ] for _ in range(num_steps)] + return (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, topk_logprobs_by_step, + topk_indices_by_step) + + def _create_logprob_lists_from_tensors( + self, + target_logprobs_by_step: torch.Tensor, + accepted_token_ids_by_step: torch.Tensor, + num_top_k: int, + ) -> Tuple[List[List[int]], List[List[float]], + List[List[List[Optional[float]]]], + List[List[List[Optional[int]]]]]: + """ + Creates and returns four lists representing token probabilities and + their ranks. + + This method initializes and returns four lists containing: + - The ranks of the accepted tokens, shaped (num_steps, batch_size) + - The log probabilities of the accepted tokens, + shaped (num_steps, batch_size) + - The log probabilities of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + - The token IDs of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + + Args: + target_logprobs_by_step (torch.Tensor): Tensor representing the + log probabilities of the target model, + shaped (num_steps, batch_size, vocab_size) + accepted_token_ids_by_step (torch.Tensor): Tensor representing + the accepted token_ids, shaped (num_steps, batch_size) + num_top_k (int): The number of top-k token log probabilities to + return. + + Returns: + A tuple containing the lists as described above. + """ + # Serialize all tensors to CPU Python lists. + # Get the logprobs/rank of the accepted tokens. 
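As a quick sanity check on the dummy path in `_create_dummy_logprob_lists` above, the placeholder lists are built entirely on the CPU with the shapes documented in its docstring, so no GPU tensor needs to be synchronized; the values below are made up purely for illustration.

```python
# Hypothetical call with batch_size=2, num_steps=3, num_top_k=2.
batch_size, num_steps, num_top_k = 2, 3, 2

ranks = [[-1] * batch_size for _ in range(num_steps)]
logprobs = [[0.0] * batch_size for _ in range(num_steps)]
topk_logprobs = [[[None] * num_top_k for _ in range(batch_size)]
                 for _ in range(num_steps)]
topk_ids = [[[None] * num_top_k for _ in range(batch_size)]
            for _ in range(num_steps)]

assert len(ranks) == num_steps and len(ranks[0]) == batch_size
assert len(topk_logprobs) == num_steps
assert len(topk_logprobs[0]) == batch_size
assert len(topk_logprobs[0][0]) == num_top_k
```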
+ (accepted_token_id_ranks_by_step_tensor, + accepted_token_id_logprobs_by_step_tensor + ) = get_sampled_token_logprobs( + logprob_tensor=target_logprobs_by_step, + sampled_token_ids=accepted_token_ids_by_step, + ) + # Get the top-k logprobs (which may or may not include the + # logprob of the accepted token). + (topk_logprobs_by_step_tensor, + topk_indices_by_step_tensor) = target_logprobs_by_step.topk( + k=num_top_k, + dim=-1, + ) + accepted_token_id_ranks_by_step = ( + accepted_token_id_ranks_by_step_tensor.tolist()) + accepted_token_id_logprobs_by_step = ( + accepted_token_id_logprobs_by_step_tensor.tolist()) + topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist() + topk_indices_by_step = topk_indices_by_step_tensor.tolist() + return (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, topk_logprobs_by_step, + topk_indices_by_step) + + def _track_finished_requests(self, execute_model_req: ExecuteModelRequest): + """ + Removes the finished requests and their associated sequence ids from + internal book keeping data structures. + """ + for finished_request in execute_model_req.finished_requests_ids: + for seq_id in self._request_id_seq_id_mapping[finished_request]: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + del self._request_id_seq_id_mapping[finished_request] + + def _track_sequences_with_bonus_tokens( + self, seq_ids: List[int], + request_ids_seq_ids_mapping: Dict[str, Set[int]], + accepted_token_ids_by_step: List[List[int]]): + """ + Updates the internal data structures which keep track of sequences + which have been assigned bonus tokens in their last forward pass. + """ + for seq_index, seq_id in enumerate(seq_ids): + last_token_id = accepted_token_ids_by_step[-1][seq_index] + if last_token_id == -1: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + else: + self._seq_with_bonus_token_in_last_step.add(seq_id) + for request_id, sequences in request_ids_seq_ids_mapping.items(): + self._request_id_seq_id_mapping[request_id].update(sequences) + + @cached_property + def _vocab_size(self) -> int: + """Get the vocab size of the model and make sure it's consistent between + draft and target workers. + """ + vocab_sizes = [ + worker.vocab_size + for worker in [self.proposer_worker, self.scorer_worker] + ] + assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes) + return vocab_sizes[0] + + @property + def rank(self): + return self.scorer_worker.rank + @property + def tensor_parallel_size(self): + return self.scorer_worker.parallel_config.tensor_parallel_size + + @property + def device(self): + return self.scorer_worker.device + + @property + def _driver_rank(self) -> int: + return 0 + + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes. + + This function is only used to compose workers within a SpecDecodeWorker. + We leave composing a SpecDecodeWorker within a SpecDecodeWorker + undefined for now, although it could be implemented in the future. + See https://arxiv.org/abs/2308.04623. 
+ """ + raise NotImplementedError + + def start_profile(self): + if isinstance(self.scorer_worker, WorkerBase): + self.scorer_worker.start_profile() + + def stop_profile(self): + if isinstance(self.scorer_worker, WorkerBase): + self.scorer_worker.stop_profile() + + +def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, + proposer_cache_block_size_bytes: int, + total_num_gpu_blocks: int) -> int: + """Given total_num_gpu_blocks, the number of GPU blocks that could be + allocate to the target model, this function calculates how many blocks + should be given to the draft and target model. + + Note that usually the block size, in bytes, of each model is different, + as it's a function of number of KV/layer, number of heads, and hidden + dimension size. + + Since the target and draft models allocate the same number of blocks, we + simply calculate the number of blocks where if allocated by both models, + the total memory usage from KV cache is no larger than the number of + blocks allocatable by the target model alone. + """ + new_num_gpu_blocks = int( + total_num_gpu_blocks * scorer_cache_block_size_bytes / + (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) + + return new_num_gpu_blocks + + +def prepare_prefill_hidden_states( + prefill_hidden_states: torch.Tensor) -> HiddenStates: + # For prefill step in proposer, we run the model for N-1 tokens + # because Nth token will be processed in the first decode step. For + # N-1 tokens, the input should be 0:N-1 hidden states which should + # be concatanated with 1:N token (since output of scorer has to be + # the input for proposer). Therefore, we shift the hidden states to + # align n-1th hidden state with nth token. + return HiddenStates(prefill_hidden_states.roll( + shifts=1, dims=0)) if prefill_hidden_states is not None else None diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py new file mode 100644 index 0000000..250a3d4 --- /dev/null +++ b/vllm/spec_decode/target_model_runner.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from vllm.sequence import SequenceGroupMetadata +from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerWrapperBase) + + +class TargetModelRunner(ModelRunnerWrapperBase): + """Specialized model runner for speculative decoding target model. + In speculative decoding, the log probabilities selected finally may not + be the same ones as selected by the target model sampling. This means + that the time spent in the log probability calculation of the target model + is time wasted, since we calculate log probabilities after deciding which + tokens are accepted. For this reason disabling log probabilities in the + target model will make decode faster. The model runner sets the + SamplingMetadata parameters according to whether log probabilities are + requested or not. + """ + + def __init__(self, model_runner: ModelRunnerBase): + # An internal boolean member variable to indicate if token log + # probabilities are needed or not. 
+ super().__init__(model_runner) + self.disable_logprobs = True + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelRunnerInputBase: + model_input: ModelRunnerInputBase =\ + self.model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) + # If token log probabilities is disabled then skip generating sampler + # CPU output. We directly serialize the GPU sampled_token_id tensors + # as needed. If log probabilities is enabled then synchronize all the + # sampling related tensors which includes the logprobs tensors. + if model_input.sampling_metadata is not None: + model_input.sampling_metadata.skip_sampler_cpu_output = ( + self.disable_logprobs) + return model_input diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py new file mode 100644 index 0000000..b538923 --- /dev/null +++ b/vllm/spec_decode/top1_proposer.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.util import sampler_output_to_torch + + +class Top1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + worker: ProposerWorkerBase, + device: str, + vocab_size: int, + max_proposal_len: Optional[int] = None, + ): + self._worker = worker + self._device = device + self.max_proposal_len = max_proposal_len + self._vocab_size = vocab_size + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + proposal_len = execute_model_req.num_lookahead_slots + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + # Split speculative- and non-speculative- sequences. + ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) + + if nonzero_proposal_len_seqs: + # Speculate tokens using the draft worker for the speculative + # sequences. 
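To make the length gating described in the `Top1Proposer` docstring concrete, here is a small sketch with invented numbers (ignoring the prompt and disabled-speculation cases): a sequence keeps the global proposal length k only if speculating k more tokens stays strictly under the draft model's maximum length.

```python
max_proposal_len = 2048   # e.g. draft model max_position_embeddings
proposal_len = 4          # global k for this scheduling step

for seq_len in (1000, 2043, 2044, 2050):
    new_k = proposal_len if seq_len + proposal_len < max_proposal_len else 0
    print(seq_len, new_k)
# prints: 1000 4
#         2043 4   (2043 + 4 = 2047 < 2048)
#         2044 0   (2044 + 4 = 2048, not strictly less)
#         2050 0
```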
+ # If sampler_transposed is true, then maybe_sampler_output's + # token_ids is like [batch] format in proposal_len size list, + # while if it is false, the format would be [proposal_len] + # in batch size list + hidden_states = execute_model_req.previous_hidden_states + if hidden_states is not None: + hidden_states.prune(nonzero_proposal_len_seqs) + nonzero_execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=nonzero_proposal_len_seqs, + num_lookahead_slots=proposal_len, + previous_hidden_states=hidden_states, + ) + maybe_sampler_output, transposed = self._worker.sampler_output( + execute_model_req=nonzero_execute_model_req, + sample_len=proposal_len, + seq_ids_with_bonus_token_in_last_step=\ + seq_ids_with_bonus_token_in_last_step, + ) + ( + proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + ) = self._remove_no_proposal_seqs(proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + transposed) + else: + # If no sequences can be speculated, set sampler output to None. + maybe_sampler_output = None + transposed = False + + # Combine speculative- and non-speculative sequences into the same + # representation. + proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( + batch_size=len(seq_group_metadata_list), + proposal_len=proposal_len, + maybe_sampler_output=maybe_sampler_output, + proposal_lens=proposal_lens, + nonzero_proposal_len_indices=nonzero_proposal_len_indices, + sampler_transposed=transposed, + ) + + proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens, + no_proposals=maybe_sampler_output + is None) + return proposals + + def _split_by_proposal_len( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_len: int, + ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: + """Split sequences by two groups: + 1. Sequences with non-zero proposal length. + 2. Sequences with zero proposal length (due to disabled speculation + or exceed the maximum model length). + """ + + proposal_lens: List[int] = [] + nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] + nonzero_proposal_len_indices: List[int] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # The speculative decoding for this request has either been disabled + # (e.g. due to high traffic) or this is a prompt request. + if (seq_group_metadata.is_prompt + or seq_group_metadata.num_speculative_tokens == 0): + proposal_lens.append(0) + continue + + seq_data = next(iter(seq_group_metadata.seq_data.values())) + seq_len = seq_data.get_len() + + # Currently only proposal lens of 0 or the global batch proposal len + # are supported. + # If max_proposal_len is defined, then we shall not exceed this + # quota for nonzero_proposal + new_k = 0 + if (self.max_proposal_len is None + or seq_len + proposal_len < self.max_proposal_len): + new_k = proposal_len + nonzero_proposal_len_seqs.append(seq_group_metadata) + nonzero_proposal_len_indices.append(i) + proposal_lens.append(new_k) + seq_group_metadata.num_speculative_tokens = new_k + + return ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) + + @staticmethod + def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices, transposed): + """Remove sequences from nonzero_proposal_len_indices and reset + their proposal_len to 0 the draft worker does not provide a proposal + (maybe_sampler_output=None). 
This can avoid scoring overheads. + """ + + # If maybe_sampler_output is None, then the draft worker did not + # provide a proposal for any sequence and thus no action needed. + # Also we do not support transposed maybe_sampler_output for now + # because it seems not straightforward for draft workers outputting + # transposed sampler outputs to handle the case of no proposal. + if maybe_sampler_output is None or transposed: + return (proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices) + + new_proposal_lens: List[int] = [] + new_nonzero_proposal_len_indices: List[int] = [] + new_maybe_sampler_output: List[SamplerOutput] = [] + nonzero_proposal_len_idx_ptr = 0 + seq_idx = 0 + while seq_idx < len( + proposal_lens) and nonzero_proposal_len_idx_ptr < len( + nonzero_proposal_len_indices): + if seq_idx < nonzero_proposal_len_indices[ + nonzero_proposal_len_idx_ptr]: + # Sequence is not in the original nonzero_proposal_len_indices, + # meaning that it has a proposal length of 0 before sending to + # the draft worker. + assert proposal_lens[seq_idx] == 0 + new_proposal_lens.append(0) + else: + # Sequence is in the original nonzero_proposal_len_indices + if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None: + # but does not have a proposal from the draft worker. + new_proposal_lens.append(0) + else: + # and has a proposal from the draft worker. Add it to the + # new nonzero proposal list and keep the sampler output. + new_proposal_lens.append(proposal_lens[seq_idx]) + new_nonzero_proposal_len_indices.append(seq_idx) + new_maybe_sampler_output.append( + maybe_sampler_output[nonzero_proposal_len_idx_ptr]) + nonzero_proposal_len_idx_ptr += 1 + seq_idx += 1 + + # The remaining sequences should have proposal length of 0. + new_proposal_lens.extend(proposal_lens[seq_idx:]) + + # We assume sampler_output will not be a list of all Nones. + # In this case this function should not be called. + assert new_maybe_sampler_output + return (new_proposal_lens, new_maybe_sampler_output, + new_nonzero_proposal_len_indices) + + def _merge_outputs( + self, + batch_size: int, + proposal_len: int, + maybe_sampler_output: Optional[List[SamplerOutput]], + proposal_lens: List[int], + nonzero_proposal_len_indices: List[int], + sampler_transposed: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """After speculations are produced, merge the speculation results with + the skipped sequences. + """ + if maybe_sampler_output is None: + # If no speculative tokens, the sampler output will be None. + # In this case we return empty proposals. + proposal_tokens = torch.tensor(-1, + dtype=torch.long, + device=self._device).expand( + batch_size, proposal_len) + proposal_probs = torch.tensor(0, + dtype=torch.float32, + device=self._device).expand( + batch_size, proposal_len, + self._vocab_size) + proposal_lens_tensor = torch.tensor(0, + dtype=torch.long, + device=self._device).expand( + len(proposal_lens)) + return proposal_tokens, proposal_probs, proposal_lens_tensor + + sampler_output = maybe_sampler_output + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( + sampler_output, sampler_transposed) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. 
[-1, -1, -1] + + entire_proposal_tokens = proposal_tokens.new_full( + size=(batch_size, *proposal_tokens.shape[1:]), + fill_value=-1, + ) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = proposal_probs.new_zeros( + batch_size, + *proposal_probs.shape[1:], + ) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = ( + entire_proposal_tokens, + entire_proposal_probs, + ) + + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len + + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py new file mode 100644 index 0000000..466269b --- /dev/null +++ b/vllm/spec_decode/util.py @@ -0,0 +1,276 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from contextlib import contextmanager +from typing import Dict, List, Optional, Sequence, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) + +SeqId = int + + +def get_all_num_logprobs( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. + + If the sampling params do not call for any logprobs, return 0 for that + sequence. + """ + + all_num_logprobs: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + num_logprobs = seq_group_metadata.sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + all_num_logprobs.append(num_logprobs) + + return all_num_logprobs + + +def get_sampled_token_logprobs( + # shape [num_steps, batch_size, vocab_size] + logprob_tensor: torch.Tensor, + sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the logprobs for the sampled tokens. Returns the ranks and logprobs. + """ + num_steps, batch_size, vocab_size = logprob_tensor.shape + + selected_logprobs = logprob_tensor[ + torch.arange(num_steps).unsqueeze(1), + torch.arange(batch_size), + sampled_token_ids, + ] + expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( + -1, -1, vocab_size) + sampled_token_ids_ranks = (logprob_tensor + > expanded_selected_logprobs).sum(-1).add_(1) + + return sampled_token_ids_ranks, selected_logprobs + + +def create_logprobs_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], +) -> Dict[int, Logprob]: + """Create a Logprob Dict for a token given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + # vLLM logprobs always include the sampled token. In addition, the user may + # request topk-logprobs (where top-k varies per user up to max_logprobs). 
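A tiny worked example (made-up logprob values) mirroring the rank computation in `get_sampled_token_logprobs` above: the rank is one plus the number of vocabulary entries that score strictly higher than the sampled token.

```python
import torch

# One step, one sequence, vocabulary of 4 tokens.
logprob_tensor = torch.tensor([[[-0.5, -1.0, -2.5, -3.0]]])  # [1, 1, 4]
sampled_token_ids = torch.tensor([[2]])                       # token id 2

num_steps, batch_size, vocab_size = logprob_tensor.shape
selected = logprob_tensor[
    torch.arange(num_steps).unsqueeze(1),
    torch.arange(batch_size),
    sampled_token_ids,
]  # logprob of the sampled token: -2.5
expanded = selected.unsqueeze(-1).expand(-1, -1, vocab_size)
rank = (logprob_tensor > expanded).sum(-1).add_(1)
print(selected.item(), rank.item())  # -2.5 3  (two entries score higher)
```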
+ logprobs: Dict[int, Logprob] = { + token_id: Logprob( + logprob=token_id_logprob, + rank=token_id_logprob_rank, + ), + } + logprobs.update({ + topk_token_id: Logprob( + logprob=topk_logprob if topk_logprob is not None else 0.0, + rank=topk_index + 1, + ) + for topk_index, (topk_token_id, topk_logprob) \ + in enumerate(zip(topk_token_ids, topk_logprobs)) \ + if topk_token_id is not None + }) + + return logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], + prompt_logprobs: Optional[PromptLogprobs] = None, + step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + step_index: (Optional[int]): The index of the speculative token. + """ + + logprobs = create_logprobs_output( + token_id, + token_id_logprob_rank, + token_id_logprob, + topk_token_ids, + topk_logprobs, + ) + + return CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs=logprobs) + ], + prompt_logprobs=prompt_logprobs, + step_index=step_index) + + +def split_batch_by_proposal_len( + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_lens: List[int], +) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[ + List[SequenceGroupMetadata], List[int]]]: + """Utility function that splits a batch based on whether the proposal len is + zero or not. We should remove this once vLLM supports per-sequence proposal + lens in a batch. + """ + + nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) + zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) + for i, (seq_group, proposal_len) in enumerate( + zip(seq_group_metadata_list, proposal_lens)): + seq_groups, indices = nonzero_lists if proposal_len else zero_lists + seq_groups.append(seq_group) + indices.append(i) + return nonzero_lists, zero_lists + + +def sampler_output_to_torch( + sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Utility function which converts a list of SamplerOutput to tensors. + + sampler_transposed here is used as the indicator for whether + we need do additional tensor transpose logic here. 
+ + Returns: + sampled_token_ids: torch.Tensor + shape: [batch_size, len(sampler_output_list)] + + sampled_token_probs: torch.Tensor + shape: [batch_size, len(sampler_output_list), vocab_size] + """ + + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_probs = torch.stack( + [ + sampler_output.sampled_token_probs + for sampler_output in sampler_output_list + ], + dim=0, + ) + + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_logprobs = torch.stack( + [sampler_output.logprobs for sampler_output in sampler_output_list], + dim=0, + ) + + # shape: [batch_size, num_sampler_output] + sampled_token_ids = torch.stack( + [ + sampler_output.sampled_token_ids.flatten() + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_token_probs = sampled_token_probs.transpose(0, 1) + sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) + sampled_token_ids = sampled_token_ids.transpose(0, 1) + + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) + + +def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, + vocab_size: int, device: str) -> None: + """Helper method which mocks out the GPU tensors in SamplerOutput with dummy + values. This will be removed in PR 7/9. + https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + """ + values = [ + sampler_output.sampled_token_probs, sampler_output.sampled_token_ids + ] + assert all(v is None for v in values) or not any(v is None for v in values) + if not any(v is None for v in values): + # Do nothing if the tensors are already created (usually in unit tests). + return + + # Softmax to ensure valid probs. + sampler_output.sampled_token_probs = torch.nn.functional.softmax( + torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), + dim=-1) + + sampler_output.sampled_token_ids = torch.randint(low=10, + high=100, + size=(batch_size, ), + dtype=torch.long, + device=device) + + +@contextmanager +def nvtx_range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an NVTX range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + If running with cuda graphs, you must enable nsys cuda graph profiling. + + Arguments: + msg (string): message to associate with the range + """ + if current_platform.is_cuda_alike(): + torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + torch.cuda.nvtx.range_pop() + else: + yield + + +class Timer: + """Basic timer context manager for measuring CPU time. 
+ """ + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time = time.time() + self.elapsed_time_s = self.end_time - self.start_time + self.elapsed_time_ms = self.elapsed_time_s * 1000 diff --git a/vllm/test_utils.py b/vllm/test_utils.py new file mode 100644 index 0000000..8611a25 --- /dev/null +++ b/vllm/test_utils.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +MODELS_ON_S3 = [ + "adept/fuyu-8b", + "ai21labs/AI21-Jamba-1.5-Mini", + "ai21labs/Jamba-tiny-random", + "ai21labs/Jamba-tiny-reward-dev", + "allenai/Molmo-7B-D-0924", + "allenai/OLMo-1B-hf", + "allenai/OLMoE-1B-7B-0924-Instruct", + "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test", + "AMead10/Llama-3.2-1B-Instruct-AWQ", + "ArthurZ/Ilama-3.2-1B", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-multilingual-gemma2", + "BAAI/bge-reranker-v2-m3", + "bigcode/starcoder2-3b", + "cross-encoder/ms-marco-MiniLM-L-6-v2", + "cross-encoder/quora-roberta-base", + "deepseek-ai/deepseek-vl2-tiny", + "distilbert/distilgpt2", + "facebook/bart-base", + "facebook/bart-large-cnn", + # "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "google/gemma-1.1-2b-it", + "google/gemma-2-2b-it", + "google/paligemma-3b-pt-224", + "h2oai/h2ovl-mississippi-800m", + "HuggingFaceM4/Idefics3-8B-Llama3", + "internlm/internlm2-1_8b-reward", + "intfloat/e5-mistral-7b-instruct", + "intfloat/multilingual-e5-small", + "jason9693/Qwen2.5-1.5B-apeach", + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + "llava-hf/LLaVA-NeXT-Video-7B-hf", + # "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-3.2-11B-Vision-Instruct", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Meta-Llama-3-8B", + "microsoft/phi-2", + "microsoft/Phi-3-mini-4k-instruct", + "microsoft/Phi-3-small-8k-instruct", + "microsoft/Phi-3-vision-128k-instruct", + "microsoft/Phi-3.5-MoE-instruct", + "microsoft/Phi-3.5-vision-instruct", + # "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "mistralai/Pixtral-12B-2409", + "mistral-community/Mixtral-8x22B-v0.1-AWQ", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", + "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", + "nm-testing/llama2.c-stories42M-pruned2.4-compressed", + "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", + "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Phi-3-mini-128k-instruct-FP8", + "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV", + "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", + "nm-testing/tinyllama-oneshot-w4a16-channel-v2", + "nm-testing/tinyllama-oneshot-w4a16-group128-v2", + 
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", + "nm-testing/tinyllama-oneshot-w8a16-per-channel", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", + "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", + "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", + "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme", + "nvidia/NVLM-D-72B", + "openai-community/gpt2", + # "openai/whisper-large-v3", + "openbmb/MiniCPM-o-2_6", + "openbmb/MiniCPM-V-2_6", + "OpenGVLab/InternVL2-1B", + "parasail-ai/GritLM-7B-vllm", + "Qwen/Qwen1.5-MoE-A2.7B-Chat", + "Qwen/Qwen2-7B-Instruct", + "Qwen/Qwen2-Audio-7B-Instruct", + "Qwen/Qwen2-VL-2B-Instruct", + "Qwen/Qwen2.5-1.5B-Instruct", + "Qwen/Qwen2.5-Math-PRM-7B", + "Qwen/Qwen2.5-Math-RM-72B", + "Qwen/Qwen2.5-VL-3B-Instruct", + "royokong/e5-v", + "sentence-transformers/all-roberta-large-v1", + "sentence-transformers/stsb-roberta-base-v2", + "shanearora/OLMo-7B-1124-hf", + "shuyuej/Llama-3.2-1B-Instruct-GPTQ", + "ssmits/Qwen2-7B-Instruct-embed-base", + "stabilityai/stablelm-3b-4e1t", + "stabilityai/stablelm-zephyr-3b", + "state-spaces/mamba-130m-hf", + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + "THUDM/glm-4v-9b", + "TIGER-Lab/Mantis-8B-siglip-llama3", + "TIGER-Lab/VLM2Vec-Full", + "tiiuae/falcon-40b", + "tiiuae/falcon-mamba-7b-instruct", + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "upstage/solar-pro-preview-instruct", +] + +MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights" diff --git a/vllm/third_party/__init__.py b/vllm/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py new file mode 100644 index 0000000..0a4be23 --- /dev/null +++ b/vllm/third_party/pynvml.py @@ -0,0 +1,6139 @@ +# SPDX-License-Identifier: Apache-2.0 +# copied from https://pypi.org/project/nvidia-ml-py +# version 12.570.86 + +##### +# Copyright (c) 2011-2023, NVIDIA Corporation. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA Corporation nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +##### + +## +# Python bindings for the NVML library +## +from ctypes import * +from ctypes.util import find_library +from functools import wraps +import sys +import os +import threading +import string + +## C Type mappings ## +## Enums +_nvmlEnableState_t = c_uint +NVML_FEATURE_DISABLED = 0 +NVML_FEATURE_ENABLED = 1 + +_nvmlBrandType_t = c_uint +NVML_BRAND_UNKNOWN = 0 +NVML_BRAND_QUADRO = 1 +NVML_BRAND_TESLA = 2 +NVML_BRAND_NVS = 3 +NVML_BRAND_GRID = 4 # Deprecated from API reporting. Keeping definition for backward compatibility. +NVML_BRAND_GEFORCE = 5 +NVML_BRAND_TITAN = 6 +NVML_BRAND_NVIDIA_VAPPS = 7 # NVIDIA Virtual Applications +NVML_BRAND_NVIDIA_VPC = 8 # NVIDIA Virtual PC +NVML_BRAND_NVIDIA_VCS = 9 # NVIDIA Virtual Compute Server +NVML_BRAND_NVIDIA_VWS = 10 # NVIDIA RTX Virtual Workstation +NVML_BRAND_NVIDIA_CLOUD_GAMING = 11 # NVIDIA Cloud Gaming +NVML_BRAND_NVIDIA_VGAMING = NVML_BRAND_NVIDIA_CLOUD_GAMING # Deprecated from API reporting. Keeping definition for backward compatibility. 
+NVML_BRAND_QUADRO_RTX = 12 +NVML_BRAND_NVIDIA_RTX = 13 +NVML_BRAND_NVIDIA = 14 +NVML_BRAND_GEFORCE_RTX = 15 # Unused +NVML_BRAND_TITAN_RTX = 16 # Unused +NVML_BRAND_COUNT = 17 + +_nvmlTemperatureThresholds_t = c_uint +NVML_TEMPERATURE_THRESHOLD_SHUTDOWN = 0 +NVML_TEMPERATURE_THRESHOLD_SLOWDOWN = 1 +NVML_TEMPERATURE_THRESHOLD_MEM_MAX = 2 +NVML_TEMPERATURE_THRESHOLD_GPU_MAX = 3 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MIN = 4 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_CURR = 5 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MAX = 6 +NVML_TEMPERATURE_THRESHOLD_GPS_CURR = 7 +NVML_TEMPERATURE_THRESHOLD_COUNT = 8 + +_nvmlTemperatureSensors_t = c_uint +NVML_TEMPERATURE_GPU = 0 +NVML_TEMPERATURE_COUNT = 1 + + +_nvmlComputeMode_t = c_uint +NVML_COMPUTEMODE_DEFAULT = 0 +NVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1 ## Support Removed +NVML_COMPUTEMODE_PROHIBITED = 2 +NVML_COMPUTEMODE_EXCLUSIVE_PROCESS = 3 +NVML_COMPUTEMODE_COUNT = 4 + +_nvmlMemoryLocation_t = c_uint +NVML_MEMORY_LOCATION_L1_CACHE = 0 +NVML_MEMORY_LOCATION_L2_CACHE = 1 +NVML_MEMORY_LOCATION_DEVICE_MEMORY = 2 +NVML_MEMORY_LOCATION_DRAM = 2 +NVML_MEMORY_LOCATION_REGISTER_FILE = 3 +NVML_MEMORY_LOCATION_TEXTURE_MEMORY = 4 +NVML_MEMORY_LOCATION_TEXTURE_SHM = 5 +NVML_MEMORY_LOCATION_CBU = 6 +NVML_MEMORY_LOCATION_SRAM = 7 +NVML_MEMORY_LOCATION_COUNT = 8 + +NVML_NVLINK_MAX_LINKS = 18 + +# For backwards compatibility, maintain the incorrectly-named "LANES" define +NVML_NVLINK_MAX_LANES = NVML_NVLINK_MAX_LINKS + +_nvmlNvLinkErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_REPLAY = 0 +NVML_NVLINK_ERROR_DL_RECOVERY = 1 +NVML_NVLINK_ERROR_DL_CRC_FLIT = 2 +NVML_NVLINK_ERROR_DL_CRC_DATA = 3 +NVML_NVLINK_ERROR_DL_ECC_DATA = 4 +NVML_NVLINK_ERROR_COUNT = 5 + +_nvmlNvLinkEccLaneErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_ECC_LANE0 = 0 +NVML_NVLINK_ERROR_DL_ECC_LANE1 = 1 +NVML_NVLINK_ERROR_DL_ECC_LANE2 = 2 +NVML_NVLINK_ERROR_DL_ECC_LANE3 = 3 +NVML_NVLINK_ERROR_DL_ECC_COUNT = 5 + +_nvmlNvLinkCapability_t = c_uint +NVML_NVLINK_CAP_P2P_SUPPORTED = 0 +NVML_NVLINK_CAP_SYSMEM_ACCESS = 1 +NVML_NVLINK_CAP_P2P_ATOMICS = 2 +NVML_NVLINK_CAP_SYSMEM_ATOMICS= 3 +NVML_NVLINK_CAP_SLI_BRIDGE = 4 +NVML_NVLINK_CAP_VALID = 5 +NVML_NVLINK_CAP_COUNT = 6 + +_nvmlNvLinkUtilizationCountPktTypes_t = c_uint +NVML_NVLINK_COUNTER_PKTFILTER_NOP = 0x1 +NVML_NVLINK_COUNTER_PKTFILTER_READ = 0x2 +NVML_NVLINK_COUNTER_PKTFILTER_WRITE = 0x4 +NVML_NVLINK_COUNTER_PKTFILTER_RATOM = 0x8 +NVML_NVLINK_COUNTER_PKTFILTER_NRATOM = 0x10 +NVML_NVLINK_COUNTER_PKTFILTER_FLUSH = 0x20 +NVML_NVLINK_COUNTER_PKTFILTER_RESPDATA = 0x40 +NVML_NVLINK_COUNTER_PKTFILTER_RESPNODATA = 0x80 +NVML_NVLINK_COUNTER_PKTFILTER_ALL = 0xFF + +_nvmlNvLinkUtilizationCountUnits_t = c_uint +NVML_NVLINK_COUNTER_UNIT_CYCLES = 0 +NVML_NVLINK_COUNTER_UNIT_PACKETS = 1 +NVML_NVLINK_COUNTER_UNIT_BYTES = 2 +NVML_NVLINK_COUNTER_UNIT_RESERVED = 3 +NVML_NVLINK_COUNTER_UNIT_COUNT = 4 + +_nvmlNvLinkDeviceType_t = c_uint +NVML_NVLINK_DEVICE_TYPE_GPU = 0x00 +NVML_NVLINK_DEVICE_TYPE_IBMNPU = 0x01 +NVML_NVLINK_DEVICE_TYPE_SWITCH = 0x02 +NVML_NVLINK_DEVICE_TYPE_UNKNOWN = 0xFF + +# These are deprecated, instead use _nvmlMemoryErrorType_t +_nvmlEccBitType_t = c_uint +NVML_SINGLE_BIT_ECC = 0 +NVML_DOUBLE_BIT_ECC = 1 +NVML_ECC_ERROR_TYPE_COUNT = 2 + +_nvmlEccCounterType_t = c_uint +NVML_VOLATILE_ECC = 0 +NVML_AGGREGATE_ECC = 1 +NVML_ECC_COUNTER_TYPE_COUNT = 2 + +_nvmlMemoryErrorType_t = c_uint +NVML_MEMORY_ERROR_TYPE_CORRECTED = 0 +NVML_MEMORY_ERROR_TYPE_UNCORRECTED = 1 +NVML_MEMORY_ERROR_TYPE_COUNT = 2 + +_nvmlClockType_t = c_uint +NVML_CLOCK_GRAPHICS = 0 +NVML_CLOCK_SM = 1 
+NVML_CLOCK_MEM = 2 +NVML_CLOCK_VIDEO = 3 +NVML_CLOCK_COUNT = 4 + +_nvmlClockId_t = c_uint +NVML_CLOCK_ID_CURRENT = 0 +NVML_CLOCK_ID_APP_CLOCK_TARGET = 1 +NVML_CLOCK_ID_APP_CLOCK_DEFAULT = 2 +NVML_CLOCK_ID_CUSTOMER_BOOST_MAX = 3 +NVML_CLOCK_ID_COUNT = 4 + +_nvmlDriverModel_t = c_uint +NVML_DRIVER_WDDM = 0 +NVML_DRIVER_WDM = 1 +NVML_DRIVER_MCDM = 2 + +NVML_MAX_GPU_PERF_PSTATES = 16 + +_nvmlPstates_t = c_uint +NVML_PSTATE_0 = 0 +NVML_PSTATE_1 = 1 +NVML_PSTATE_2 = 2 +NVML_PSTATE_3 = 3 +NVML_PSTATE_4 = 4 +NVML_PSTATE_5 = 5 +NVML_PSTATE_6 = 6 +NVML_PSTATE_7 = 7 +NVML_PSTATE_8 = 8 +NVML_PSTATE_9 = 9 +NVML_PSTATE_10 = 10 +NVML_PSTATE_11 = 11 +NVML_PSTATE_12 = 12 +NVML_PSTATE_13 = 13 +NVML_PSTATE_14 = 14 +NVML_PSTATE_15 = 15 +NVML_PSTATE_UNKNOWN = 32 + +_nvmlInforomObject_t = c_uint +NVML_INFOROM_OEM = 0 +NVML_INFOROM_ECC = 1 +NVML_INFOROM_POWER = 2 +NVML_INFOROM_DEN = 3 +NVML_INFOROM_COUNT = 4 + +_nvmlReturn_t = c_uint +NVML_SUCCESS = 0 +NVML_ERROR_UNINITIALIZED = 1 +NVML_ERROR_INVALID_ARGUMENT = 2 +NVML_ERROR_NOT_SUPPORTED = 3 +NVML_ERROR_NO_PERMISSION = 4 +NVML_ERROR_ALREADY_INITIALIZED = 5 +NVML_ERROR_NOT_FOUND = 6 +NVML_ERROR_INSUFFICIENT_SIZE = 7 +NVML_ERROR_INSUFFICIENT_POWER = 8 +NVML_ERROR_DRIVER_NOT_LOADED = 9 +NVML_ERROR_TIMEOUT = 10 +NVML_ERROR_IRQ_ISSUE = 11 +NVML_ERROR_LIBRARY_NOT_FOUND = 12 +NVML_ERROR_FUNCTION_NOT_FOUND = 13 +NVML_ERROR_CORRUPTED_INFOROM = 14 +NVML_ERROR_GPU_IS_LOST = 15 +NVML_ERROR_RESET_REQUIRED = 16 +NVML_ERROR_OPERATING_SYSTEM = 17 +NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18 +NVML_ERROR_IN_USE = 19 +NVML_ERROR_MEMORY = 20 +NVML_ERROR_NO_DATA = 21 +NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22 +NVML_ERROR_INSUFFICIENT_RESOURCES = 23 +NVML_ERROR_FREQ_NOT_SUPPORTED = 24 +NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25 +NVML_ERROR_DEPRECATED = 26 +NVML_ERROR_NOT_READY = 27 +NVML_ERROR_GPU_NOT_FOUND = 28 +NVML_ERROR_INVALID_STATE = 29 +NVML_ERROR_UNKNOWN = 999 + +_nvmlFanState_t = c_uint +NVML_FAN_NORMAL = 0 +NVML_FAN_FAILED = 1 + +_nvmlFanControlPolicy_t = c_uint +NVML_FAN_POLICY_TEMPERATURE_CONTINOUS_SW = 0 +NVML_FAN_POLICY_MANUAL = 1 + +_nvmlLedColor_t = c_uint +NVML_LED_COLOR_GREEN = 0 +NVML_LED_COLOR_AMBER = 1 + +_nvmlGpuOperationMode_t = c_uint +NVML_GOM_ALL_ON = 0 +NVML_GOM_COMPUTE = 1 +NVML_GOM_LOW_DP = 2 + +_nvmlPageRetirementCause_t = c_uint +NVML_PAGE_RETIREMENT_CAUSE_MULTIPLE_SINGLE_BIT_ECC_ERRORS = 0 +NVML_PAGE_RETIREMENT_CAUSE_DOUBLE_BIT_ECC_ERROR = 1 +NVML_PAGE_RETIREMENT_CAUSE_COUNT = 2 + +_nvmlRestrictedAPI_t = c_uint +NVML_RESTRICTED_API_SET_APPLICATION_CLOCKS = 0 +NVML_RESTRICTED_API_SET_AUTO_BOOSTED_CLOCKS = 1 +NVML_RESTRICTED_API_COUNT = 2 + +_nvmlBridgeChipType_t = c_uint +NVML_BRIDGE_CHIP_PLX = 0 +NVML_BRIDGE_CHIP_BRO4 = 1 +NVML_MAX_PHYSICAL_BRIDGE = 128 + +_nvmlValueType_t = c_uint +NVML_VALUE_TYPE_DOUBLE = 0 +NVML_VALUE_TYPE_UNSIGNED_INT = 1 +NVML_VALUE_TYPE_UNSIGNED_LONG = 2 +NVML_VALUE_TYPE_UNSIGNED_LONG_LONG = 3 +NVML_VALUE_TYPE_SIGNED_LONG_LONG = 4 +NVML_VALUE_TYPE_SIGNED_INT = 5 +NVML_VALUE_TYPE_UNSIGNED_SHORT = 6 +NVML_VALUE_TYPE_COUNT = 7 + +_nvmlNvlinkVersion_t = c_uint +NVML_NVLINK_VERSION_INVALID = 0 +NVML_NVLINK_VERSION_1_0 = 1 +NVML_NVLINK_VERSION_2_0 = 2 +NVML_NVLINK_VERSION_2_2 = 3 +NVML_NVLINK_VERSION_3_0 = 4 +NVML_NVLINK_VERSION_3_1 = 5 +NVML_NVLINK_VERSION_4_0 = 6 +NVML_NVLINK_VERSION_5_0 = 7 + +_nvmlPerfPolicyType_t = c_uint +NVML_PERF_POLICY_POWER = 0 +NVML_PERF_POLICY_THERMAL = 1 +NVML_PERF_POLICY_SYNC_BOOST = 2 +NVML_PERF_POLICY_BOARD_LIMIT = 3 +NVML_PERF_POLICY_LOW_UTILIZATION = 4 +NVML_PERF_POLICY_RELIABILITY = 5 
+NVML_PERF_POLICY_TOTAL_APP_CLOCKS = 10 +NVML_PERF_POLICY_TOTAL_BASE_CLOCKS = 11 +NVML_PERF_POLICY_COUNT = 12 + +_nvmlEncoderQueryType_t = c_uint +NVML_ENCODER_QUERY_H264 = 0 +NVML_ENCODER_QUERY_HEVC = 1 +NVML_ENCODER_QUERY_AV1 = 2 +NVML_ENCODER_QUERY_UNKNOWN = 255 + +_nvmlFBCSessionType_t = c_uint +NVML_FBC_SESSION_TYPE_UNKNOWN = 0 +NVML_FBC_SESSION_TYPE_TOSYS = 1 +NVML_FBC_SESSION_TYPE_CUDA = 2 +NVML_FBC_SESSION_TYPE_VID = 3 +NVML_FBC_SESSION_TYPE_HWENC = 4 + +_nvmlDetachGpuState_t = c_uint +NVML_DETACH_GPU_KEEP = 0 +NVML_DETACH_GPU_REMOVE = 1 + +_nvmlPcieLinkState_t = c_uint +NVML_PCIE_LINK_KEEP = 0 +NVML_PCIE_LINK_SHUT_DOWN = 1 + +_nvmlSamplingType_t = c_uint +NVML_TOTAL_POWER_SAMPLES = 0 +NVML_GPU_UTILIZATION_SAMPLES = 1 +NVML_MEMORY_UTILIZATION_SAMPLES = 2 +NVML_ENC_UTILIZATION_SAMPLES = 3 +NVML_DEC_UTILIZATION_SAMPLES = 4 +NVML_PROCESSOR_CLK_SAMPLES = 5 +NVML_MEMORY_CLK_SAMPLES = 6 +NVML_MODULE_POWER_SAMPLES = 7 +NVML_JPG_UTILIZATION_SAMPLES = 8 +NVML_OFA_UTILIZATION_SAMPLES = 9 +NVML_SAMPLINGTYPE_COUNT = 10 + +_nvmlPcieUtilCounter_t = c_uint +NVML_PCIE_UTIL_TX_BYTES = 0 +NVML_PCIE_UTIL_RX_BYTES = 1 +NVML_PCIE_UTIL_COUNT = 2 + +_nvmlGpuTopologyLevel_t = c_uint +NVML_TOPOLOGY_INTERNAL = 0 +NVML_TOPOLOGY_SINGLE = 10 +NVML_TOPOLOGY_MULTIPLE = 20 +NVML_TOPOLOGY_HOSTBRIDGE = 30 +NVML_TOPOLOGY_NODE = 40 +NVML_TOPOLOGY_CPU = NVML_TOPOLOGY_NODE +NVML_TOPOLOGY_SYSTEM = 50 + +_nvmlGpuP2PCapsIndex_t = c_uint +NVML_P2P_CAPS_INDEX_READ = 0, +NVML_P2P_CAPS_INDEX_WRITE = 1 +NVML_P2P_CAPS_INDEX_NVLINK =2 +NVML_P2P_CAPS_INDEX_ATOMICS = 3 +# +# NVML_P2P_CAPS_INDEX_PROP is deprecated. +# Use NVML_P2P_CAPS_INDEX_PCI instead. +# +NVML_P2P_CAPS_INDEX_PROP = 4 +NVML_P2P_CAPS_INDEX_PCI = 4 +NVML_P2P_CAPS_INDEX_UNKNOWN = 5 + +_nvmlGpuP2PStatus_t = c_uint +NVML_P2P_STATUS_OK = 0 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED = 1 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED = NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED +NVML_P2P_STATUS_GPU_NOT_SUPPORTED = 2 +NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED =3 +NVML_P2P_STATUS_DISABLED_BY_REGKEY =4 +NVML_P2P_STATUS_NOT_SUPPORTED =5 +NVML_P2P_STATUS_UNKNOWN =6 + +_nvmlDeviceArchitecture_t = c_uint +NVML_DEVICE_ARCH_KEPLER = 2 +NVML_DEVICE_ARCH_MAXWELL = 3 +NVML_DEVICE_ARCH_PASCAL = 4 +NVML_DEVICE_ARCH_VOLTA = 5 +NVML_DEVICE_ARCH_TURING = 6 +NVML_DEVICE_ARCH_AMPERE = 7 +NVML_DEVICE_ARCH_ADA = 8 +NVML_DEVICE_ARCH_HOPPER = 9 +NVML_DEVICE_ARCH_BLACKWELL = 10 +NVML_DEVICE_ARCH_T23X = 11 +NVML_DEVICE_ARCH_UNKNOWN = 0xffffffff + +# PCI bus Types +_nvmlBusType_t = c_uint +NVML_BUS_TYPE_UNKNOWN = 0 +NVML_BUS_TYPE_PCI = 1 +NVML_BUS_TYPE_PCIE = 2 +NVML_BUS_TYPE_FPCI = 3 +NVML_BUS_TYPE_AGP = 4 + +_nvmlPowerSource_t = c_uint +NVML_POWER_SOURCE_AC = 0x00000000 +NVML_POWER_SOURCE_BATTERY = 0x00000001 +NVML_POWER_SOURCE_UNDERSIZED = 0x00000002 + +_nvmlAdaptiveClockInfoStatus_t = c_uint +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_DISABLED = 0x00000000 +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_ENABLED = 0x00000001 + +_nvmlClockLimitId_t = c_uint +NVML_CLOCK_LIMIT_ID_RANGE_START = 0xffffff00 +NVML_CLOCK_LIMIT_ID_TDP = 0xffffff01 +NVML_CLOCK_LIMIT_ID_UNLIMITED = 0xffffff02 + +_nvmlPcieLinkMaxSpeed_t = c_uint +NVML_PCIE_LINK_MAX_SPEED_INVALID = 0x00000000 +NVML_PCIE_LINK_MAX_SPEED_2500MBPS = 0x00000001 +NVML_PCIE_LINK_MAX_SPEED_5000MBPS = 0x00000002 +NVML_PCIE_LINK_MAX_SPEED_8000MBPS = 0x00000003 +NVML_PCIE_LINK_MAX_SPEED_16000MBPS = 0x00000004 +NVML_PCIE_LINK_MAX_SPEED_32000MBPS = 0x00000005 +NVML_PCIE_LINK_MAX_SPEED_64000MBPS = 0x00000006 + +_nvmlPcieAtomicsCapability_t = c_uint +NVML_PCIE_ATOMICS_CAP_FETCHADD32 = 0x01 
+NVML_PCIE_ATOMICS_CAP_FETCHADD64 = 0x02 +NVML_PCIE_ATOMICS_CAP_SWAP32 = 0x04 +NVML_PCIE_ATOMICS_CAP_SWAP64 = 0x08 +NVML_PCIE_ATOMICS_CAP_CAS32 = 0x10 +NVML_PCIE_ATOMICS_CAP_CAS64 = 0x20 +NVML_PCIE_ATOMICS_CAP_CAS128 = 0x40 +NVML_PCIE_ATOMICS_OPS_MAX = 7 + +_nvmlAffinityScope_t = c_uint +NVML_AFFINITY_SCOPE_NODE = 0 +NVML_AFFINITY_SCOPE_SOCKET = 1 + +_nvmlDeviceGpuRecoveryAction_t = c_uint +NVML_GPU_RECOVERY_ACTION_NONE = 0 +NVML_GPU_RECOVERY_ACTION_GPU_RESET = 1 +NVML_GPU_RECOVERY_ACTION_NODE_REBOOT = 2 +NVML_GPU_RECOVERY_ACTION_DRAIN_P2P = 3 +NVML_GPU_RECOVERY_ACTION_DRAIN_AND_RESET = 4 + +# C preprocessor defined values +nvmlFlagDefault = 0 +nvmlFlagForce = 1 +NVML_INIT_FLAG_NO_GPUS = 1 +NVML_INIT_FLAG_NO_ATTACH = 2 + +NVML_MAX_GPC_COUNT = 32 + +# buffer size +NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE = 16 +NVML_DEVICE_UUID_BUFFER_SIZE = 80 +NVML_DEVICE_UUID_V2_BUFFER_SIZE = 96 +NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE = 80 +NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE = 80 +NVML_DEVICE_NAME_BUFFER_SIZE = 64 +NVML_DEVICE_NAME_V2_BUFFER_SIZE = 96 +NVML_DEVICE_SERIAL_BUFFER_SIZE = 30 +NVML_DEVICE_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_GPU_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE = 16 +NVML_GRID_LICENSE_BUFFER_SIZE = 128 +NVML_VGPU_NAME_BUFFER_SIZE = 64 +NVML_GRID_LICENSE_FEATURE_MAX_COUNT = 3 +NVML_VGPU_METADATA_OPAQUE_DATA_SIZE = sizeof(c_uint) + 256 +NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE = 256 +NVML_DEVICE_GPU_FRU_PART_NUMBER_BUFFER_SIZE = 0x14 # NV2080_GPU_MAX_PRODUCT_PART_NUMBER_LENGTH +NVML_PERF_MODES_BUFFER_SIZE = 2048 + +# Format strings +NVML_DEVICE_PCI_BUS_ID_LEGACY_FMT = "%04X:%02X:%02X.0" +NVML_DEVICE_PCI_BUS_ID_FMT = "%08X:%02X:%02X.0" + +NVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1) +NVML_VALUE_NOT_AVAILABLE_uint = c_uint(-1) + +''' + Field Identifiers. + + All Identifiers pertain to a device. Each ID is only used once and is guaranteed never to change. +''' +NVML_FI_DEV_ECC_CURRENT = 1 # Current ECC mode. 1=Active. 0=Inactive +NVML_FI_DEV_ECC_PENDING = 2 # Pending ECC mode. 1=Active. 
0=Inactive + +#ECC Count Totals +NVML_FI_DEV_ECC_SBE_VOL_TOTAL = 3 # Total single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TOTAL = 4 # Total double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TOTAL = 5 # Total single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TOTAL = 6 # Total double bit aggregate (persistent) ECC errors +#Individual ECC locations +NVML_FI_DEV_ECC_SBE_VOL_L1 = 7 # L1 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L1 = 8 # L1 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_L2 = 9 # L2 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L2 = 10 # L2 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_DEV = 11 # Device memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_DEV = 12 # Device memory double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_REG = 13 # Register file single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_REG = 14 # Register file double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_TEX = 15 # Texture memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TEX = 16 # Texture memory double bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_CBU = 17 # CBU double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L1 = 18 # L1 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L1 = 19 # L1 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L2 = 20 # L2 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L2 = 21 # L2 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_DEV = 22 # Device memory single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_DEV = 23 # Device memory double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_REG = 24 # Register File single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_REG = 25 # Register File double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TEX = 26 # Texture memory single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TEX = 27 # Texture memory double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_CBU = 28 # CBU double bit aggregate ECC errors + +# Page Retirement +NVML_FI_DEV_RETIRED_SBE = 29 # Number of retired pages because of single bit errors +NVML_FI_DEV_RETIRED_DBE = 30 # Number of retired pages because of double bit errors +NVML_FI_DEV_RETIRED_PENDING = 31 # If any pages are pending retirement. 1=yes. 0=no. 
+ +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L0 = 32 # NVLink flow control CRC Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L1 = 33 # NVLink flow control CRC Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L2 = 34 # NVLink flow control CRC Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L3 = 35 # NVLink flow control CRC Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L4 = 36 # NVLink flow control CRC Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L5 = 37 # NVLink flow control CRC Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_TOTAL = 38 # NVLink flow control CRC Error Counter total for all Lanes + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L0 = 39 # NVLink data CRC Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L1 = 40 # NVLink data CRC Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L2 = 41 # NVLink data CRC Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L3 = 42 # NVLink data CRC Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L4 = 43 # NVLink data CRC Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L5 = 44 # NVLink data CRC Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_TOTAL = 45 # NvLink data CRC Error Counter total for all Lanes + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L0 = 46 # NVLink Replay Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L1 = 47 # NVLink Replay Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L2 = 48 # NVLink Replay Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L3 = 49 # NVLink Replay Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L4 = 50 # NVLink Replay Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L5 = 51 # NVLink Replay Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_TOTAL = 52 # NVLink Replay Error Counter total for all Lanes + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L0 = 53 # NVLink Recovery Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L1 = 54 # NVLink Recovery Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L2 = 55 # NVLink Recovery Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L3 = 56 # NVLink Recovery Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L4 = 57 # NVLink Recovery Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L5 = 58 # NVLink Recovery Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_TOTAL = 59 # NVLink Recovery Error Counter total for all Lanes + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L0 = 60 # NVLink Bandwidth Counter for Counter Set 0, Lane 0 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L1 = 61 # NVLink Bandwidth Counter for Counter Set 0, Lane 1 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L2 = 62 # NVLink Bandwidth Counter for Counter Set 0, Lane 2 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L3 = 63 # NVLink Bandwidth Counter for Counter Set 0, Lane 3 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L4 = 64 # NVLink Bandwidth Counter for Counter Set 0, Lane 4 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L5 = 65 # NVLink Bandwidth Counter for Counter Set 0, Lane 5 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_TOTAL = 66 # NVLink Bandwidth Counter Total for Counter Set 0, All Lanes + +# NvLink Bandwidth Counters 
+NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L0 = 67 # NVLink Bandwidth Counter for Counter Set 1, Lane 0 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L1 = 68 # NVLink Bandwidth Counter for Counter Set 1, Lane 1 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L2 = 69 # NVLink Bandwidth Counter for Counter Set 1, Lane 2 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L3 = 70 # NVLink Bandwidth Counter for Counter Set 1, Lane 3 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L4 = 71 # NVLink Bandwidth Counter for Counter Set 1, Lane 4 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L5 = 72 # NVLink Bandwidth Counter for Counter Set 1, Lane 5 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_TOTAL = 73 # NVLink Bandwidth Counter Total for Counter Set 1, All Lanes + +# Perf Policy Counters +NVML_FI_DEV_PERF_POLICY_POWER = 74 # Perf Policy Counter for Power Policy +NVML_FI_DEV_PERF_POLICY_THERMAL = 75 # Perf Policy Counter for Thermal Policy +NVML_FI_DEV_PERF_POLICY_SYNC_BOOST = 76 # Perf Policy Counter for Sync boost Policy +NVML_FI_DEV_PERF_POLICY_BOARD_LIMIT = 77 # Perf Policy Counter for Board Limit +NVML_FI_DEV_PERF_POLICY_LOW_UTILIZATION = 78 # Perf Policy Counter for Low GPU Utilization Policy +NVML_FI_DEV_PERF_POLICY_RELIABILITY = 79 # Perf Policy Counter for Reliability Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_APP_CLOCKS = 80 # Perf Policy Counter for Total App Clock Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_BASE_CLOCKS = 81 # Perf Policy Counter for Total Base Clocks Policy + +# Memory temperatures +NVML_FI_DEV_MEMORY_TEMP = 82 # Memory temperature for the device + +# Energy Counter +NVML_FI_DEV_TOTAL_ENERGY_CONSUMPTION = 83 # Total energy consumption for the GPU in mJ since the driver was last reloaded + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L0 = 84 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L1 = 85 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L2 = 86 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L3 = 87 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L4 = 88 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L5 = 89 +NVML_FI_DEV_NVLINK_SPEED_MBPS_COMMON = 90 + +# NVLink Link Count +NVML_FI_DEV_NVLINK_LINK_COUNT = 91 + +# Page Retirement pending fields +NVML_FI_DEV_RETIRED_PENDING_SBE = 92 +NVML_FI_DEV_RETIRED_PENDING_DBE = 93 + +# PCIe replay and replay rollover counters +NVML_FI_DEV_PCIE_REPLAY_COUNTER = 94 +NVML_FI_DEV_PCIE_REPLAY_ROLLOVER_COUNTER = 95 + +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L6 = 96 # NVLink flow control CRC Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L7 = 97 # NVLink flow control CRC Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L8 = 98 # NVLink flow control CRC Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L9 = 99 # NVLink flow control CRC Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L10 = 100 # NVLink flow control CRC Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L11 = 101 # NVLink flow control CRC Error Counter for Lane 11 + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L6 = 102 # NVLink data CRC Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L7 = 103 # NVLink data CRC Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L8 = 104 # NVLink data CRC Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L9 = 105 # NVLink data CRC Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L10 = 106 # NVLink data CRC Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L11 = 107 # NVLink data CRC Error Counter for Lane 11 + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L6 = 108 # NVLink 
Replay Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L7 = 109 # NVLink Replay Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L8 = 110 # NVLink Replay Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L9 = 111 # NVLink Replay Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L10 = 112 # NVLink Replay Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L11 = 113 # NVLink Replay Error Counter for Lane 11 + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L6 = 114 # NVLink Recovery Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L7 = 115 # NVLink Recovery Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L8 = 116 # NVLink Recovery Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L9 = 117 # NVLink Recovery Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L10 = 118 # NVLink Recovery Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L11 = 119 # NVLink Recovery Error Counter for Lane 11 + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L6 = 120 # NVLink Bandwidth Counter for Counter Set 0, Lane 6 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L7 = 121 # NVLink Bandwidth Counter for Counter Set 0, Lane 7 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L8 = 122 # NVLink Bandwidth Counter for Counter Set 0, Lane 8 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L9 = 123 # NVLink Bandwidth Counter for Counter Set 0, Lane 9 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L10 = 124 # NVLink Bandwidth Counter for Counter Set 0, Lane 10 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L11 = 125 # NVLink Bandwidth Counter for Counter Set 0, Lane 11 + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L6 = 126 # NVLink Bandwidth Counter for Counter Set 1, Lane 6 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L7 = 127 # NVLink Bandwidth Counter for Counter Set 1, Lane 7 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L8 = 128 # NVLink Bandwidth Counter for Counter Set 1, Lane 8 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L9 = 129 # NVLink Bandwidth Counter for Counter Set 1, Lane 9 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L10 = 130 # NVLink Bandwidth Counter for Counter Set 1, Lane 10 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L11 = 131 # NVLink Bandwidth Counter for Counter Set 1, Lane 11 + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L6 = 132 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L7 = 133 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L8 = 134 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L9 = 135 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L10 = 136 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L11 = 137 + +# NVLink Throughput Counters +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_TX = 138 # NVLink TX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_RX = 139 # NVLink RX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_TX = 140 # NVLink TX Data + protocol overhead in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_RX = 141 # NVLink RX Data + protocol overhead in KiB + +# Row Remapper +NVML_FI_DEV_REMAPPED_COR = 142 +NVML_FI_DEV_REMAPPED_UNC = 143 +NVML_FI_DEV_REMAPPED_PENDING = 144 +NVML_FI_DEV_REMAPPED_FAILURE = 145 + +#Remote device NVLink ID +NVML_FI_DEV_NVLINK_REMOTE_NVLINK_ID = 146 + +# Number of NVLinks connected to NVSwitch +NVML_FI_DEV_NVSWITCH_CONNECTED_LINK_COUNT = 147 + +# NvLink ECC Data Error Counters +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L0 = 148 #< NVLink data ECC Error Counter for Link 0 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L1 = 149 #< NVLink data ECC Error Counter for Link 1 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L2 = 150 #< NVLink data ECC Error Counter for Link 2 
+NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L3 = 151 #< NVLink data ECC Error Counter for Link 3 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L4 = 152 #< NVLink data ECC Error Counter for Link 4 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L5 = 153 #< NVLink data ECC Error Counter for Link 5 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L6 = 154 #< NVLink data ECC Error Counter for Link 6 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L7 = 155 #< NVLink data ECC Error Counter for Link 7 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L8 = 156 #< NVLink data ECC Error Counter for Link 8 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L9 = 157 #< NVLink data ECC Error Counter for Link 9 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L10 = 158 #< NVLink data ECC Error Counter for Link 10 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L11 = 159 #< NVLink data ECC Error Counter for Link 11 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_TOTAL = 160 #< NvLink data ECC Error Counter total for all Links + +NVML_FI_DEV_NVLINK_ERROR_DL_REPLAY = 161 +NVML_FI_DEV_NVLINK_ERROR_DL_RECOVERY = 162 +NVML_FI_DEV_NVLINK_ERROR_DL_CRC = 163 +NVML_FI_DEV_NVLINK_GET_SPEED = 164 +NVML_FI_DEV_NVLINK_GET_STATE = 165 +NVML_FI_DEV_NVLINK_GET_VERSION = 166 + +NVML_FI_DEV_NVLINK_GET_POWER_STATE = 167 +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD = 168 + +NVML_FI_DEV_PCIE_L0_TO_RECOVERY_COUNTER = 169 + +NVML_FI_DEV_C2C_LINK_COUNT = 170 +NVML_FI_DEV_C2C_LINK_GET_STATUS = 171 +NVML_FI_DEV_C2C_LINK_GET_MAX_BW = 172 + +NVML_FI_DEV_PCIE_COUNT_CORRECTABLE_ERRORS = 173 +NVML_FI_DEV_PCIE_COUNT_NAKS_RECEIVED = 174 +NVML_FI_DEV_PCIE_COUNT_RECEIVER_ERROR = 175 +NVML_FI_DEV_PCIE_COUNT_BAD_TLP = 176 +NVML_FI_DEV_PCIE_COUNT_NAKS_SENT = 177 +NVML_FI_DEV_PCIE_COUNT_BAD_DLLP = 178 +NVML_FI_DEV_PCIE_COUNT_NON_FATAL_ERROR = 179 +NVML_FI_DEV_PCIE_COUNT_FATAL_ERROR = 180 +NVML_FI_DEV_PCIE_COUNT_UNSUPPORTED_REQ = 181 +NVML_FI_DEV_PCIE_COUNT_LCRC_ERROR = 182 +NVML_FI_DEV_PCIE_COUNT_LANE_ERROR = 183 + +NVML_FI_DEV_IS_RESETLESS_MIG_SUPPORTED = 184 + +NVML_FI_DEV_POWER_AVERAGE = 185 +NVML_FI_DEV_POWER_INSTANT = 186 +NVML_FI_DEV_POWER_MIN_LIMIT = 187 +NVML_FI_DEV_POWER_MAX_LIMIT = 188 +NVML_FI_DEV_POWER_DEFAULT_LIMIT = 189 +NVML_FI_DEV_POWER_CURRENT_LIMIT = 190 +NVML_FI_DEV_ENERGY = 191 +NVML_FI_DEV_POWER_REQUESTED_LIMIT = 192 + +NVML_FI_DEV_TEMPERATURE_SHUTDOWN_TLIMIT = 193 +NVML_FI_DEV_TEMPERATURE_SLOWDOWN_TLIMIT = 194 +NVML_FI_DEV_TEMPERATURE_MEM_MAX_TLIMIT = 195 +NVML_FI_DEV_TEMPERATURE_GPU_MAX_TLIMIT = 196 + +NVML_FI_DEV_PCIE_COUNT_TX_BYTES = 197 +NVML_FI_DEV_PCIE_COUNT_RX_BYTES = 198 + +NVML_FI_DEV_IS_MIG_MODE_INDEPENDENT_MIG_QUERY_CAPABLE = 199 + +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MAX = 200 + +NVML_FI_DEV_NVLINK_COUNT_XMIT_PACKETS = 201 +NVML_FI_DEV_NVLINK_COUNT_XMIT_BYTES = 202 +NVML_FI_DEV_NVLINK_COUNT_RCV_PACKETS = 203 +NVML_FI_DEV_NVLINK_COUNT_RCV_BYTES = 204 +NVML_FI_DEV_NVLINK_COUNT_VL15_DROPPED = 205 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_MALFORMED_PACKET_ERRORS = 206 +NVML_FI_DEV_NVLINK_COUNT_BUFFER_OVERRUN_ERRORS = 207 +NVML_FI_DEV_NVLINK_COUNT_RCV_ERRORS = 208 +NVML_FI_DEV_NVLINK_COUNT_RCV_REMOTE_ERRORS = 209 +NVML_FI_DEV_NVLINK_COUNT_RCV_GENERAL_ERRORS = 210 +NVML_FI_DEV_NVLINK_COUNT_LOCAL_LINK_INTEGRITY_ERRORS = 211 +NVML_FI_DEV_NVLINK_COUNT_XMIT_DISCARDS = 212 + +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_SUCCESSFUL_EVENTS = 213 +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_FAILED_EVENTS = 214 +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_EVENTS = 215 + +NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE0 = 216 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE1 = 217 # Deprecated, do 
not use +NVML_FI_DEV_NVLINK_COUNT_RAW_BER = 218 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_ERRORS = 219 +NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_BER = 220 +NVML_FI_DEV_NVLINK_COUNT_SYMBOL_ERRORS = 221 +NVML_FI_DEV_NVLINK_COUNT_SYMBOL_BER = 222 + +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MIN = 223 +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS = 224 # Values are in the form NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_* +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_SUPPORTED = 225 + +NVML_FI_DEV_RESET_STATUS = 226 # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead +NVML_FI_DEV_DRAIN_AND_RESET_STATUS = 227 # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead +NVML_FI_DEV_PCIE_OUTBOUND_ATOMICS_MASK = 228 +NVML_FI_DEV_PCIE_INBOUND_ATOMICS_MASK = 229 +NVML_FI_DEV_GET_GPU_RECOVERY_ACTION = 230 + +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_0 = 235 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_1 = 236 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_2 = 237 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_3 = 238 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_4 = 239 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_5 = 240 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_6 = 241 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_7 = 242 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_8 = 243 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_9 = 244 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_10 = 245 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_11 = 246 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_12 = 247 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_13 = 248 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_14 = 249 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_15 = 250 +NVML_FI_PWR_SMOOTHING_ENABLED = 251 # Enablement (0/DISABLED or 1/ENABLED) +NVML_FI_PWR_SMOOTHING_PRIV_LVL = 252 # Current privilege level +NVML_FI_PWR_SMOOTHING_IMM_RAMP_DOWN_ENABLED = 253 # Immediate ramp down enablement (0/DISABLED or 1/ENABLED) +NVML_FI_PWR_SMOOTHING_APPLIED_TMP_CEIL = 254 # Applied TMP ceiling value +NVML_FI_PWR_SMOOTHING_APPLIED_TMP_FLOOR = 255 # Applied TMP floor value +NVML_FI_PWR_SMOOTHING_MAX_PERCENT_TMP_FLOOR_SETTING = 256 # Max % TMP Floor value +NVML_FI_PWR_SMOOTHING_MIN_PERCENT_TMP_FLOOR_SETTING = 257 # Min % TMP Floor value +NVML_FI_PWR_SMOOTHING_HW_CIRCUITRY_PERCENT_LIFETIME_REMAINING = 258 # HW Circuitry % lifetime remaining +NVML_FI_PWR_SMOOTHING_MAX_NUM_PRESET_PROFILES = 259 # Max number of preset profiles +NVML_FI_PWR_SMOOTHING_PROFILE_PERCENT_TMP_FLOOR = 260 # % TMP floor for a given profile +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_UP_RATE = 261 # Ramp up rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_RATE = 262 # Ramp down rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_HYST_VAL = 263 # Ramp down hysteresis value in ms for a given profile +NVML_FI_PWR_SMOOTHING_ACTIVE_PRESET_PROFILE = 264 # Active preset profile number +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_PERCENT_TMP_FLOOR = 265 # % TMP floor for a given profile +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_UP_RATE = 266 # Ramp up rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_RATE = 267 # Ramp down rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_HYST_VAL = 268 # Ramp down hysteresis value in ms for a given profile + +NVML_FI_MAX = 269 # One greater than the largest field ID defined above + +# NVML_FI_DEV_NVLINK_GET_STATE state enums +NVML_NVLINK_STATE_INACTIVE = 0x0 +NVML_NVLINK_STATE_ACTIVE = 0x1 +NVML_NVLINK_STATE_SLEEP = 0x2 + +NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_100US = 0 # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS 
+NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_50US = 1 # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS + +## Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode +NVML_GPU_VIRTUALIZATION_MODE_NONE = 0 # Represents Bare Metal GPU +NVML_GPU_VIRTUALIZATION_MODE_PASSTHROUGH = 1 # Device is associated with GPU-Passthorugh +NVML_GPU_VIRTUALIZATION_MODE_VGPU = 2 # Device is associated with vGPU inside virtual machine. +NVML_GPU_VIRTUALIZATION_MODE_HOST_VGPU = 3 # Device is associated with VGX hypervisor in vGPU mode +NVML_GPU_VIRTUALIZATION_MODE_HOST_VSGA = 4 # Device is associated with VGX hypervisor in vSGA mode + +## Lib loading ## +nvmlLib = None +libLoadLock = threading.Lock() +_nvmlLib_refcount = 0 # Incremented on each nvmlInit and decremented on nvmlShutdown + +## vGPU Management +_nvmlVgpuTypeId_t = c_uint +_nvmlVgpuInstance_t = c_uint + +_nvmlVgpuVmIdType_t = c_uint +NVML_VGPU_VM_ID_DOMAIN_ID = 0 +NVML_VGPU_VM_ID_UUID = 1 + +_nvmlGridLicenseFeatureCode_t = c_uint +NVML_GRID_LICENSE_FEATURE_CODE_UNKNOWN = 0 +NVML_GRID_LICENSE_FEATURE_CODE_VGPU = 1 +NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX = 2 +NVML_GRID_LICENSE_FEATURE_CODE_VWORKSTATION = 2 # deprecated, use NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX. +NVML_GRID_LICENSE_FEATURE_CODE_GAMING = 3 +NVML_GRID_LICENSE_FEATURE_CODE_COMPUTE = 4 + +_nvmlGridLicenseExpiryStatus_t = c_uint8 +NVML_GRID_LICENSE_EXPIRY_NOT_AVAILABLE = 0, # Expiry information not available +NVML_GRID_LICENSE_EXPIRY_INVALID = 1, # Invalid expiry or error fetching expiry +NVML_GRID_LICENSE_EXPIRY_VALID = 2, # Valid expiry +NVML_GRID_LICENSE_EXPIRY_NOT_APPLICABLE = 3, # Expiry not applicable +NVML_GRID_LICENSE_EXPIRY_PERMANENT = 4, # Permanent expiry + +_nvmlVgpuCapability_t = c_uint +NVML_VGPU_CAP_NVLINK_P2P = 0 # vGPU P2P over NVLink is supported +NVML_VGPU_CAP_GPUDIRECT = 1 # GPUDirect capability is supported +NVML_VGPU_CAP_MULTI_VGPU_EXCLUSIVE = 2 # vGPU profile cannot be mixed with other vGPU profiles in same VM +NVML_VGPU_CAP_EXCLUSIVE_TYPE = 3 # vGPU profile cannot run on a GPU alongside other profiles of different type +NVML_VGPU_CAP_EXCLUSIVE_SIZE = 4 # vGPU profile cannot run on a GPU alongside other profiles of different size +NVML_VGPU_CAP_COUNT = 5 + +_nvmlVgpuDriverCapability_t = c_uint +NVML_VGPU_DRIVER_CAP_HETEROGENEOUS_MULTI_VGPU = 0 # Supports mixing of different vGPU profiles within one guest VM +NVML_VGPU_DRIVER_CAP_WARM_UPDATE = 1 # Supports FSR and warm update of vGPU host driver without terminating the running guest VM +NVML_VGPU_DRIVER_CAP_COUNT = 2 + +_nvmlDeviceVgpuCapability_t = c_uint +NVML_DEVICE_VGPU_CAP_FRACTIONAL_MULTI_VGPU = 0 # Query whether the fractional vGPU profiles on this GPU can be used in multi-vGPU configurations +NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_PROFILES = 1 # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing types +NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_SIZES = 2 # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing framebuffer sizes +NVML_DEVICE_VGPU_CAP_READ_DEVICE_BUFFER_BW = 3 # Query the GPU's read_device_buffer expected bandwidth capacity in megabytes per second +NVML_DEVICE_VGPU_CAP_WRITE_DEVICE_BUFFER_BW = 4 # Query the GPU's write_device_buffer expected bandwidth capacity in megabytes per second +NVML_DEVICE_VGPU_CAP_DEVICE_STREAMING = 5 # Query whether the vGPU profiles on the GPU supports migration data streaming +NVML_DEVICE_VGPU_CAP_MINI_QUARTER_GPU = 6 # Set/Get support of 
mini-quarter vGPU profiles +NVML_DEVICE_VGPU_CAP_COMPUTE_MEDIA_ENGINE_GPU = 7 # Set/Get support for compute media engine vGPU profiles +NVML_DEVICE_VGPU_CAP_WARM_UPDATE = 8 # Query whether the GPU supports FSR and warm update +NVML_DEVICE_VGPU_CAP_HOMOGENEOUS_PLACEMENTS = 9 # Query whether the GPU supports reporting of placements of timesliced vGPU profiles with identical framebuffer sizes +NVML_DEVICE_VGPU_CAP_COUNT = 10 + +_nvmlVgpuGuestInfoState_t = c_uint +NVML_VGPU_INSTANCE_GUEST_INFO_STATE_UNINITIALIZED = 0 +NVML_VGPU_INSTANCE_GUEST_INFO_STATE_INITIALIZED = 1 + +_nvmlVgpuVmCompatibility_t = c_uint +NVML_VGPU_VM_COMPATIBILITY_NONE = 0x0 +NVML_VGPU_VM_COMPATIBILITY_COLD = 0x1 +NVML_VGPU_VM_COMPATIBILITY_HIBERNATE = 0x2 +NVML_VGPU_VM_COMPATIBILITY_SLEEP = 0x4 +NVML_VGPU_VM_COMPATIBILITY_LIVE = 0x8 + +_nvmlVgpuPgpuCompatibilityLimitCode_t = c_uint +NVML_VGPU_COMPATIBILITY_LIMIT_NONE = 0x0 +NVML_VGPU_COMPATIBILITY_LIMIT_HOST_DRIVER = 0x1 +NVML_VGPU_COMPATIBILITY_LIMIT_GUEST_DRIVER = 0x2 +NVML_VGPU_COMPATIBILITY_LIMIT_GPU = 0x4 +NVML_VGPU_COMPATIBILITY_LIMIT_OTHER = 0x80000000 + +_nvmlHostVgpuMode_t = c_uint +NVML_HOST_VGPU_MODE_NON_SRIOV = 0 +NVML_HOST_VGPU_MODE_SRIOV = 1 + +_nvmlConfComputeGpusReadyState_t = c_uint +NVML_CC_ACCEPTING_CLIENT_REQUESTS_FALSE = 0 +NVML_CC_ACCEPTING_CLIENT_REQUESTS_TRUE = 1 + +_nvmlConfComputeGpuCaps_t = c_uint +NVML_CC_SYSTEM_GPUS_CC_NOT_CAPABLE = 0 +NVML_CC_SYSTEM_GPUS_CC_CAPABLE = 1 + +_nvmlConfComputeCpuCaps_t = c_uint +NVML_CC_SYSTEM_CPU_CAPS_NONE = 0 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV = 1 +NVML_CC_SYSTEM_CPU_CAPS_INTEL_TDX = 2 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV_SNP = 3 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SNP_VTOM = 4 + +_nvmlConfComputeDevToolsMode_t = c_uint +NVML_CC_SYSTEM_DEVTOOLS_MODE_OFF = 0 +NVML_CC_SYSTEM_DEVTOOLS_MODE_ON = 1 + +NVML_CC_SYSTEM_MULTIGPU_NONE = 0 +NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE = 1 + +NVML_CC_SYSTEM_ENVIRONMENT_UNAVAILABLE = 0 +NVML_CC_SYSTEM_ENVIRONMENT_SIM = 1 +NVML_CC_SYSTEM_ENVIRONMENT_PROD = 2 + +_nvmlConfComputeCcFeature_t = c_uint +NVML_CC_SYSTEM_FEATURE_DISABLED = 0 +NVML_CC_SYSTEM_FEATURE_ENABLED = 1 + +_nvmlConfComputeCcKeyRotationThreshAttackerAdv_t = c_uint +NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MIN = 50 +NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MAX = 65 + +# GSP firmware +NVML_GSP_FIRMWARE_VERSION_BUF_SIZE = 0x40 + +class NVMLLibraryMismatchError(Exception): + pass + +## Error Checking ## +class NVMLError(Exception): + _valClassMapping = dict() + # List of currently known error codes + _errcode_to_string = { + NVML_ERROR_UNINITIALIZED: "Uninitialized", + NVML_ERROR_INVALID_ARGUMENT: "Invalid Argument", + NVML_ERROR_NOT_SUPPORTED: "Not Supported", + NVML_ERROR_NO_PERMISSION: "Insufficient Permissions", + NVML_ERROR_ALREADY_INITIALIZED: "Already Initialized", + NVML_ERROR_NOT_FOUND: "Not Found", + NVML_ERROR_INSUFFICIENT_SIZE: "Insufficient Size", + NVML_ERROR_INSUFFICIENT_POWER: "Insufficient External Power", + NVML_ERROR_DRIVER_NOT_LOADED: "Driver Not Loaded", + NVML_ERROR_TIMEOUT: "Timeout", + NVML_ERROR_IRQ_ISSUE: "Interrupt Request Issue", + NVML_ERROR_LIBRARY_NOT_FOUND: "NVML Shared Library Not Found", + NVML_ERROR_FUNCTION_NOT_FOUND: "Function Not Found", + NVML_ERROR_CORRUPTED_INFOROM: "Corrupted infoROM", + NVML_ERROR_GPU_IS_LOST: "GPU is lost", + NVML_ERROR_RESET_REQUIRED: "GPU requires restart", + NVML_ERROR_OPERATING_SYSTEM: "The operating system has blocked the request.", + NVML_ERROR_LIB_RM_VERSION_MISMATCH: "RM has detected an NVML/RM version mismatch.", + NVML_ERROR_MEMORY: "Insufficient 
Memory", + NVML_ERROR_UNKNOWN: "Unknown Error", + } + def __new__(typ, value): + ''' + Maps value to a proper subclass of NVMLError. + See _extractNVMLErrorsAsClasses function for more details + ''' + if typ == NVMLError: + typ = NVMLError._valClassMapping.get(value, typ) + obj = Exception.__new__(typ) + obj.value = value + return obj + def __str__(self): + try: + if self.value not in NVMLError._errcode_to_string: + NVMLError._errcode_to_string[self.value] = str(nvmlErrorString(self.value)) + return NVMLError._errcode_to_string[self.value] + except NVMLError: + return "NVML Error with code %d" % self.value + def __eq__(self, other): + return self.value == other.value + +def nvmlExceptionClass(nvmlErrorCode): + if nvmlErrorCode not in NVMLError._valClassMapping: + raise ValueError('nvmlErrorCode %s is not valid' % nvmlErrorCode) + return NVMLError._valClassMapping[nvmlErrorCode] + +def _extractNVMLErrorsAsClasses(): + ''' + Generates a hierarchy of classes on top of NVMLError class. + + Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate + exceptions more easily. + + NVMLError is a parent class. Each NVML_ERROR_* gets it's own subclass. + e.g. NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized + ''' + this_module = sys.modules[__name__] + nvmlErrorsNames = [x for x in dir(this_module) if x.startswith("NVML_ERROR_")] + for err_name in nvmlErrorsNames: + # e.g. Turn NVML_ERROR_ALREADY_INITIALIZED into NVMLError_AlreadyInitialized + class_name = "NVMLError_" + string.capwords(err_name.replace("NVML_ERROR_", ""), "_").replace("_", "") + err_val = getattr(this_module, err_name) + def gen_new(val): + def new(typ): + obj = NVMLError.__new__(typ, val) + return obj + return new + new_error_class = type(class_name, (NVMLError,), {'__new__': gen_new(err_val)}) + new_error_class.__module__ = __name__ + setattr(this_module, class_name, new_error_class) + NVMLError._valClassMapping[err_val] = new_error_class +_extractNVMLErrorsAsClasses() + +def _nvmlCheckReturn(ret): + if (ret != NVML_SUCCESS): + raise NVMLError(ret) + return ret + +## Function access ## +_nvmlGetFunctionPointer_cache = dict() # function pointers are cached to prevent unnecessary libLoadLock locking +def _nvmlGetFunctionPointer(name): + global nvmlLib + + if name in _nvmlGetFunctionPointer_cache: + return _nvmlGetFunctionPointer_cache[name] + + libLoadLock.acquire() + try: + # ensure library was loaded + if (nvmlLib == None): + raise NVMLError(NVML_ERROR_UNINITIALIZED) + try: + _nvmlGetFunctionPointer_cache[name] = getattr(nvmlLib, name) + return _nvmlGetFunctionPointer_cache[name] + except AttributeError: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + +## Alternative object +# Allows the object to be printed +# Allows mismatched types to be assigned +# - like None when the Structure variant requires c_uint +class nvmlFriendlyObject(object): + def __init__(self, dictionary): + for x in dictionary: + setattr(self, x, dictionary[x]) + def __str__(self): + return self.__dict__.__str__() + +def nvmlStructToFriendlyObject(struct): + d = {} + for x in struct._fields_: + key = x[0] + value = getattr(struct, key) + # only need to convert from bytes if bytes, no need to check python version. 
+ d[key] = value.decode() if isinstance(value, bytes) else value + obj = nvmlFriendlyObject(d) + return obj + +# pack the object so it can be passed to the NVML library +def nvmlFriendlyObjectToStruct(obj, model): + for x in model._fields_: + key = x[0] + value = obj.__dict__[key] + # any c_char_p in python3 needs to be bytes, default encoding works fine. + if sys.version_info >= (3,): + setattr(model, key, value.encode()) + else: + setattr(model, key, value) + return model + +## Unit structures +class struct_c_nvmlUnit_t(Structure): + pass # opaque handle +c_nvmlUnit_t = POINTER(struct_c_nvmlUnit_t) + +class _PrintableStructure(Structure): + """ + Abstract class that produces nicer __str__ output than ctypes.Structure. + e.g. instead of: + >>> print str(obj) + + this class will print + class_name(field_name: formatted_value, field_name: formatted_value) + + _fmt_ dictionary of -> + e.g. class that has _field_ 'hex_value', c_uint could be formatted with + _fmt_ = {"hex_value" : "%08X"} + to produce nicer output. + Default fomratting string for all fields can be set with key "" like: + _fmt_ = {"" : "%d MHz"} # e.g all values are numbers in MHz. + If not set it's assumed to be just "%s" + + Exact format of returned str from this class is subject to change in the future. + """ + _fmt_ = {} + def __str__(self): + result = [] + for x in self._fields_: + key = x[0] + value = getattr(self, key) + fmt = "%s" + if key in self._fmt_: + fmt = self._fmt_[key] + elif "" in self._fmt_: + fmt = self._fmt_[""] + result.append(("%s: " + fmt) % (key, value)) + return self.__class__.__name__ + "(" + ", ".join(result) + ")" + + def __getattribute__(self, name): + res = super(_PrintableStructure, self).__getattribute__(name) + # need to convert bytes to unicode for python3 don't need to for python2 + # Python 2 strings are of both str and bytes + # Python 3 strings are not of type bytes + # ctypes should convert everything to the correct values otherwise + if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + def __setattr__(self, name, value): + if isinstance(value, str): + # encoding a python2 string returns the same value, since python2 strings are bytes already + # bytes passed in python3 will be ignored. 
+ value = value.encode() + super(_PrintableStructure, self).__setattr__(name, value) + +class c_nvmlUnitInfo_t(_PrintableStructure): + _fields_ = [ + ('name', c_char * 96), + ('id', c_char * 96), + ('serial', c_char * 96), + ('firmwareVersion', c_char * 96), + ] + +class c_nvmlC2cModeInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('isC2cEnabled', c_uint) + ] + +nvmlC2cModeInfo_v1 = 0x1000008; + +class c_nvmlLedState_t(_PrintableStructure): + _fields_ = [ + ('cause', c_char * 256), + ('color', _nvmlLedColor_t), + ] + +class c_nvmlPSUInfo_t(_PrintableStructure): + _fields_ = [ + ('state', c_char * 256), + ('current', c_uint), + ('voltage', c_uint), + ('power', c_uint), + ] + +class c_nvmlUnitFanInfo_t(_PrintableStructure): + _fields_ = [ + ('speed', c_uint), + ('state', _nvmlFanState_t), + ] + +class c_nvmlUnitFanSpeeds_t(_PrintableStructure): + _fields_ = [ + ('fans', c_nvmlUnitFanInfo_t * 24), + ('count', c_uint) + ] + +## Device structures +class struct_c_nvmlDevice_t(Structure): + pass # opaque handle +c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t) + +class nvmlPciInfoExt_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('domain', c_uint), + ('bus', c_uint), + ('device', c_uint), + ('pciDeviceId', c_uint), + ('pciSubSystemId', c_uint), + ('baseClass', c_uint), + ('subClass', c_uint), + ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE), + ] + _fmt_ = { + 'version' : "0x%04X", + 'domain' : "0x%04X", + 'bus' : "0x%02X", + 'device' : "0x%02X", + 'pciDeviceId' : "0x%08X", + 'pciSubSystemId' : "0x%08X", + 'baseClass' : "0x%01X", + 'subClass' : "0x%01X", + } + +nvmlPciInfoExt_v1 = 0x1000040 + +# Legacy pciInfo used for _v1 and _v2 +class nvmlPciInfo_v2_t(_PrintableStructure): + _fields_ = [ + ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE), + ('domain', c_uint), + ('bus', c_uint), + ('device', c_uint), + ('pciDeviceId', c_uint), + + # Added in 2.285 + ('pciSubSystemId', c_uint), + ('reserved0', c_uint), + ('reserved1', c_uint), + ('reserved2', c_uint), + ('reserved3', c_uint), + ] + _fmt_ = { + 'domain' : "0x%04X", + 'bus' : "0x%02X", + 'device' : "0x%02X", + 'pciDeviceId' : "0x%08X", + 'pciSubSystemId' : "0x%08X", + } + +class nvmlPciInfo_t(_PrintableStructure): + _fields_ = [ + # Moved to the new busId location below + ('busIdLegacy', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE), + ('domain', c_uint), + ('bus', c_uint), + ('device', c_uint), + ('pciDeviceId', c_uint), + + # Added in 2.285 + ('pciSubSystemId', c_uint), + # New busId replaced the long deprecated and reserved fields with a + # field of the same size in 9.0 + ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE), + ] + _fmt_ = { + 'domain' : "0x%08X", + 'bus' : "0x%02X", + 'device' : "0x%02X", + 'pciDeviceId' : "0x%08X", + 'pciSubSystemId' : "0x%08X", + } + +class c_nvmlSystemDriverBranchInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ("branch", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ] + +SystemDriverBranchInfo_v1 = 0x1000054 + +class c_nvmlExcludedDeviceInfo_t(_PrintableStructure): + _fields_ = [ + ('pci', nvmlPciInfo_t), + ('uuid', c_char * NVML_DEVICE_UUID_BUFFER_SIZE) + ] + +class nvmlNvLinkUtilizationControl_t(_PrintableStructure): + _fields_ = [ + ('units', _nvmlNvLinkUtilizationCountUnits_t), + ('pktfilter', _nvmlNvLinkUtilizationCountPktTypes_t), + ] + +class c_nvmlMemory_t(_PrintableStructure): + _fields_ = [ + ('total', c_ulonglong), + ('free', c_ulonglong), + ('used', c_ulonglong), + ] + _fmt_ = {'': "%d B"} + +class 
c_nvmlMemory_v2_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('total', c_ulonglong), + ('reserved', c_ulonglong), + ('free', c_ulonglong), + ('used', c_ulonglong), + ] + _fmt_ = {'': "%d B"} + +nvmlMemory_v2 = 0x02000028 + +class c_nvmlBAR1Memory_t(_PrintableStructure): + _fields_ = [ + ('bar1Total', c_ulonglong), + ('bar1Free', c_ulonglong), + ('bar1Used', c_ulonglong), + ] + _fmt_ = {'': "%d B"} + +class nvmlClkMonFaultInfo_t(Structure): + _fields_ = [("clkApiDomain", c_uint), + ("clkDomainFaultMask", c_uint) + ] + +MAX_CLK_DOMAINS = 32 + +class nvmlClkMonStatus_t(Structure): + _fields_ = [("bGlobalStatus", c_uint), + ("clkMonListSize", c_uint), + ("clkMonList", nvmlClkMonFaultInfo_t * MAX_CLK_DOMAINS) + ] + +# On Windows with the WDDM driver, usedGpuMemory is reported as None +# Code that processes this structure should check for None, I.E. +# +# if (info.usedGpuMemory == None): +# # TODO handle the error +# pass +# else: +# print("Using %d MiB of memory" % (info.usedGpuMemory / 1024 / 1024)) +# endif +# +# See NVML documentation for more information +class c_nvmlProcessInfo_v2_t(_PrintableStructure): + _fields_ = [ + ('pid', c_uint), + ('usedGpuMemory', c_ulonglong), + ('gpuInstanceId', c_uint), + ('computeInstanceId', c_uint), + ] + _fmt_ = {'usedGpuMemory': "%d B"} + +c_nvmlProcessInfo_v3_t = c_nvmlProcessInfo_v2_t + +c_nvmlProcessInfo_t = c_nvmlProcessInfo_v3_t + +_nvmlProcessMode_t = c_uint +NVML_PROCESS_MODE_COMPUTE = 0 +NVML_PROCESS_MODE_GRAPHICS = 1 +NVML_PROCESS_MODE_MPS = 2 + +class c_nvmlProcessDetail_v1_t(Structure): + _fields_ = [ + ('pid', c_uint), + ('usedGpuMemory', c_ulonglong), + ('gpuInstanceId', c_uint), + ('computeInstanceId', c_uint), + ('usedGpuCcProtectedMemory', c_ulonglong), + ] + +class c_nvmlProcessDetailList_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('mode', _nvmlProcessMode_t), + ('numProcArrayEntries', c_uint), + ('procArray', POINTER(c_nvmlProcessDetail_v1_t)), + ] + _fmt_ = {'numProcArrayEntries': "%d B"} + +c_nvmlProcessDetailList_t = c_nvmlProcessDetailList_v1_t + +nvmlProcessDetailList_v1 = 0x1000018 + +class c_nvmlBridgeChipInfo_t(_PrintableStructure): + _fields_ = [ + ('type', _nvmlBridgeChipType_t), + ('fwVersion', c_uint), + ] + +class c_nvmlBridgeChipHierarchy_t(_PrintableStructure): + _fields_ = [ + ('bridgeCount', c_uint), + ('bridgeChipInfo', c_nvmlBridgeChipInfo_t * 128), + ] + +class c_nvmlEccErrorCounts_t(_PrintableStructure): + _fields_ = [ + ('l1Cache', c_ulonglong), + ('l2Cache', c_ulonglong), + ('deviceMemory', c_ulonglong), + ('registerFile', c_ulonglong), + ] + +class c_nvmlUtilization_t(_PrintableStructure): + _fields_ = [ + ('gpu', c_uint), + ('memory', c_uint), + ] + _fmt_ = {'': "%d %%"} + +# Added in 2.285 +class c_nvmlHwbcEntry_t(_PrintableStructure): + _fields_ = [ + ('hwbcId', c_uint), + ('firmwareVersion', c_char * 32), + ] + +class c_nvmlValue_t(Union): + _fields_ = [ + ('dVal', c_double), + ('uiVal', c_uint), + ('ulVal', c_ulong), + ('ullVal', c_ulonglong), + ('sllVal', c_longlong), + ('siVal', c_int), + ('usVal', c_ushort), + ] + +class c_nvmlSample_t(_PrintableStructure): + _fields_ = [ + ('timeStamp', c_ulonglong), + ('sampleValue', c_nvmlValue_t), + ] + +class c_nvmlViolationTime_t(_PrintableStructure): + _fields_ = [ + ('referenceTime', c_ulonglong), + ('violationTime', c_ulonglong), + ] + +class c_nvmlFieldValue_t(_PrintableStructure): + _fields_ = [ + ('fieldId', c_uint32), + ('scopeId', c_uint32), + ('timestamp', c_int64), + ('latencyUsec', c_int64), + ('valueType', 
_nvmlValueType_t), + ('nvmlReturn', _nvmlReturn_t), + ('value', c_nvmlValue_t) + ] + +NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES = 23 + +nvmlNvlinkSupportedBwModes_v1 = 0x100001c +class c_nvmlNvlinkSupportedBwModes_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bwModes', c_uint8 * NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES), + ('totalBwModes', c_uint8) + ] + + def __init__(self): + super(c_nvmlNvlinkSupportedBwModes_v1_t, self).__init__(version=nvmlNvlinkSupportedBwModes_v1) + +nvmlNvlinkGetBwMode_v1 = 0x100000c +class c_nvmlNvlinkGetBwMode_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bIsBest', c_uint), + ('bwMode', c_uint8) + ] + + def __init__(self): + super(c_nvmlNvlinkGetBwMode_v1_t, self).__init__(version=nvmlNvlinkGetBwMode_v1) + +nvmlNvlinkSetBwMode_v1 = 0x100000c +class c_nvmlNvlinkSetBwMode_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bSetBest', c_uint), + ('bwMode', c_uint8) + ] + + def __init__(self): + super(c_nvmlNvlinkSetBwMode_v1_t, self).__init__(version=nvmlNvlinkSetBwMode_v1) + +class c_nvmlVgpuHeterogeneousMode_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('mode', c_uint), + ] + +VgpuHeterogeneousMode_v1 = 0x1000008 + +class c_nvmlVgpuPlacementId_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('placementId', c_uint), + ] + +VgpuPlacementId_v1 = 0x1000008 + +class c_nvmlVgpuPlacementList_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('count', c_uint), + ('placementSize', c_uint), + ('placementIds', POINTER(c_uint)), + ] + +VgpuPlacementList_v1 = 0x1000018 + +NVML_VGPU_PGPU_HETEROGENEOUS_MODE = 0 +NVML_VGPU_PGPU_HOMOGENEOUS_MODE = 1 + +class c_nvmlVgpuPlacementList_v2_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('placementSize', c_uint), + ('count', c_uint), + ('placementIds', POINTER(c_uint)), + ('mode', c_uint), + ] + +VgpuPlacementList_v2 = 0x2000020 + +class c_nvmlVgpuTypeBar1Info_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bar1Size', c_ulonglong), + ] + +VgpuTypeBar1Info_v1 = 0x1000010 + +class c_nvmlVgpuInstanceUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ('vgpuInstance', _nvmlVgpuInstance_t), + ('timeStamp', c_ulonglong), + ('smUtil', c_nvmlValue_t), + ('memUtil', c_nvmlValue_t), + ('encUtil', c_nvmlValue_t), + ('decUtil', c_nvmlValue_t), + ] + +class c_nvmlVgpuInstanceUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('timeStamp', c_ulonglong), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('smUtil', c_nvmlValue_t), + ('memUtil', c_nvmlValue_t), + ('encUtil', c_nvmlValue_t), + ('decUtil', c_nvmlValue_t), + ('jpgUtil', c_nvmlValue_t), + ('ofaUtil', c_nvmlValue_t), + ] + +class c_nvmlVgpuInstancesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('sampleValType', _nvmlValueType_t), + ('vgpuInstanceCount', c_uint), + ('lastSeenTimeStamp', c_ulonglong), + ('vgpuUtilArray', POINTER(c_nvmlVgpuInstanceUtilizationInfo_v1_t)), + ] + +VgpuInstancesUtilizationInfo_v1 = 0x01000020 + +class c_nvmlVgpuProcessUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ('vgpuInstance', _nvmlVgpuInstance_t), + ('pid', c_uint), + ('processName', c_char * NVML_VGPU_NAME_BUFFER_SIZE), + ('timeStamp', c_ulonglong), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ] + +class c_nvmlVgpuProcessUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('processName', c_char * NVML_VGPU_NAME_BUFFER_SIZE), + ('timeStamp', 
c_ulonglong), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('pid', c_uint), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ('jpgUtil', c_uint), + ('ofaUtil', c_uint), + ] + +class c_nvmlVgpuProcessesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('vgpuProcessCount', c_uint), + ('lastSeenTimeStamp', c_ulonglong), + ('vgpuProcUtilArray', POINTER(c_nvmlVgpuProcessUtilizationInfo_v1_t)), + ] + +VgpuProcessesUtilizationInfo_v1 = 0x01000018 + +class nvmlVgpuRuntimeState_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('size', c_ulonglong), + ] + +VgpuRuntimeState_v1 = 0x1000010 + +class c_nvmlVgpuLicenseExpiry_t(_PrintableStructure): + _fields_ = [ + ('year', c_uint32), + ('month', c_uint16), + ('day', c_uint16), + ('hour', c_uint16), + ('min', c_uint16), + ('sec', c_uint16), + ('status', c_uint8), + ] + +NVML_GRID_LICENSE_STATE_UNKNOWN = 0 +NVML_GRID_LICENSE_STATE_UNINITIALIZED = 1 +NVML_GRID_LICENSE_STATE_UNLICENSED_UNRESTRICTED = 2 +NVML_GRID_LICENSE_STATE_UNLICENSED_RESTRICTED = 3 +NVML_GRID_LICENSE_STATE_UNLICENSED = 4 +NVML_GRID_LICENSE_STATE_LICENSED = 5 + +class c_nvmlVgpuLicenseInfo_t(_PrintableStructure): + _fields_ = [ + ('isLicensed', c_uint8), + ('licenseExpiry', c_nvmlVgpuLicenseExpiry_t), + ('currentState', c_uint), + ] + +class c_nvmlEncoderSession_t(_PrintableStructure): + _fields_ = [ + ('sessionId', c_uint), + ('pid', c_uint), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('codecType', c_uint), + ('hResolution', c_uint), + ('vResolution', c_uint), + ('averageFps', c_uint), + ('encodeLatency', c_uint), + ] + +class c_nvmlProcessUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ('pid', c_uint), + ('timeStamp', c_ulonglong), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ] + +class c_nvmlProcessUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('timeStamp', c_ulonglong), + ('pid', c_uint), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ('jpgUtil', c_uint), + ('ofaUtil', c_uint), + ] + +class c_nvmlProcessesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('processSamplesCount', c_uint), + ('lastSeenTimeStamp', c_ulonglong), + ('procUtilArray', POINTER(c_nvmlProcessUtilizationInfo_v1_t)), + ] + +ProcessesUtilizationInfo_v1 = 0x01000018 + +class c_nvmlGridLicenseExpiry_t(_PrintableStructure): + _fields_ = [ + ('year', c_uint32), + ('month', c_uint16), + ('day', c_uint16), + ('hour', c_uint16), + ('min', c_uint16), + ('sec', c_uint16), + ('status', c_uint8), + ] + +class c_nvmlGridLicensableFeature_v4_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('productName', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('featureEnabled', c_uint), + ('licenseExpiry', c_nvmlGridLicenseExpiry_t), + ] + +class c_nvmlGridLicensableFeatures_v4_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_v4_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlGridLicensableFeature_v3_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('productName', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + 
('featureEnabled', c_uint), + ] + +class c_nvmlGridLicensableFeatures_v3_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_v3_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlGridLicensableFeature_v2_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('productName', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + +class c_nvmlGridLicensableFeatures_v2_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_v2_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlGridLicensableFeature_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + +class c_nvmlGridLicensableFeatures_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlMarginTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('marginTemperature', c_int), + ] + +nvmlMarginTemperature_v1 = 0x1000008 + +## Event structures +class struct_c_nvmlEventSet_t(Structure): + pass # opaque handle +c_nvmlEventSet_t = POINTER(struct_c_nvmlEventSet_t) + +nvmlEventTypeSingleBitEccError = 0x0000000000000001 +nvmlEventTypeDoubleBitEccError = 0x0000000000000002 +nvmlEventTypePState = 0x0000000000000004 +nvmlEventTypeXidCriticalError = 0x0000000000000008 +nvmlEventTypeClock = 0x0000000000000010 +nvmlEventTypePowerSourceChange = 0x0000000000000080 +nvmlEventMigConfigChange = 0x0000000000000100 +nvmlEventTypeSingleBitEccErrorStorm = 0x0000000000000200 +nvmlEventTypeDramRetirementEvent = 0x0000000000000400 +nvmlEventTypeDramRetirementFailure = 0x0000000000000800 +nvmlEventTypeNonFatalPoisonError = 0x0000000000001000 +nvmlEventTypeFatalPoisonError = 0x0000000000002000 +nvmlEventTypeGpuUnavailableError = 0x0000000000004000 +nvmlEventTypeGpuRecoveryAction = 0x0000000000008000 +nvmlEventTypeNone = 0x0000000000000000 +nvmlEventTypeAll = ( + nvmlEventTypeNone + | nvmlEventTypeSingleBitEccError + | nvmlEventTypeDoubleBitEccError + | nvmlEventTypePState + | nvmlEventTypeClock + | nvmlEventTypePowerSourceChange + | nvmlEventTypeXidCriticalError + | nvmlEventMigConfigChange + | nvmlEventTypeSingleBitEccErrorStorm + | nvmlEventTypeDramRetirementEvent + | nvmlEventTypeDramRetirementFailure + | nvmlEventTypeNonFatalPoisonError + | nvmlEventTypeFatalPoisonError + | nvmlEventTypeGpuUnavailableError + | nvmlEventTypeGpuRecoveryAction + ) + +## Clock Event Reasons defines +nvmlClocksEventReasonGpuIdle = 0x0000000000000001 +nvmlClocksEventReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksEventReasonUserDefinedClocks = nvmlClocksEventReasonApplicationsClocksSetting # deprecated, use nvmlClocksEventReasonApplicationsClocksSetting +nvmlClocksEventReasonSwPowerCap = 0x0000000000000004 +nvmlClocksEventReasonHwSlowdown = 0x0000000000000008 +nvmlClocksEventReasonSyncBoost = 0x0000000000000010 +nvmlClocksEventReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksEventReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksEventReasonHwPowerBrakeSlowdown = 
0x0000000000000080 +nvmlClocksEventReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksEventReasonNone = 0x0000000000000000 +nvmlClocksEventReasonAll = ( + nvmlClocksEventReasonNone | + nvmlClocksEventReasonGpuIdle | + nvmlClocksEventReasonApplicationsClocksSetting | + nvmlClocksEventReasonSwPowerCap | + nvmlClocksEventReasonHwSlowdown | + nvmlClocksEventReasonSyncBoost | + nvmlClocksEventReasonSwThermalSlowdown | + nvmlClocksEventReasonHwThermalSlowdown | + nvmlClocksEventReasonHwPowerBrakeSlowdown | + nvmlClocksEventReasonDisplayClockSetting + ) + +## Following have been deprecated +nvmlClocksThrottleReasonGpuIdle = 0x0000000000000001 +nvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting +nvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004 +nvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008 +nvmlClocksThrottleReasonSyncBoost = 0x0000000000000010 +nvmlClocksThrottleReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksThrottleReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksThrottleReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksThrottleReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksThrottleReasonNone = 0x0000000000000000 +nvmlClocksThrottleReasonAll = ( + nvmlClocksThrottleReasonNone | + nvmlClocksThrottleReasonGpuIdle | + nvmlClocksThrottleReasonApplicationsClocksSetting | + nvmlClocksThrottleReasonSwPowerCap | + nvmlClocksThrottleReasonHwSlowdown | + nvmlClocksThrottleReasonSyncBoost | + nvmlClocksThrottleReasonSwThermalSlowdown | + nvmlClocksThrottleReasonHwThermalSlowdown | + nvmlClocksThrottleReasonHwPowerBrakeSlowdown | + nvmlClocksThrottleReasonDisplayClockSetting + ) + +class c_nvmlEventData_t(_PrintableStructure): + _fields_ = [ + ('device', c_nvmlDevice_t), + ('eventType', c_ulonglong), + ('eventData', c_ulonglong), + ('gpuInstanceId', c_uint), + ('computeInstanceId', c_uint) + ] + _fmt_ = {'eventType': "0x%08X"} + +class c_nvmlAccountingStats_t(_PrintableStructure): + _fields_ = [ + ('gpuUtilization', c_uint), + ('memoryUtilization', c_uint), + ('maxMemoryUsage', c_ulonglong), + ('time', c_ulonglong), + ('startTime', c_ulonglong), + ('isRunning', c_uint), + ('reserved', c_uint * 5) + ] + +class c_nvmlVgpuVersion_t(Structure): + _fields_ = [("minVersion", c_uint), + ("maxVersion", c_uint) + ] + +class c_nvmlVgpuMetadata_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("revision", c_uint), + ("guestInfoState", _nvmlVgpuGuestInfoState_t), + ("guestDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("reserved", c_uint * 6), + ("vgpuVirtualizationCaps", c_uint), + ("guestVgpuVersion", c_uint), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_METADATA_OPAQUE_DATA_SIZE) + ] + +class c_nvmlVgpuPgpuMetadata_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("revision", c_uint), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("pgpuVirtualizationCaps", c_uint), + ("reserved", c_uint * 5), + ("hostSupportedVgpuRange", c_nvmlVgpuVersion_t), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + ] + +class c_nvmlVgpuPgpuCompatibility_t(Structure): + _fields_ = [("vgpuVmCompatibility", _nvmlVgpuVmCompatibility_t), + ("compatibilityLimitCode", 
_nvmlVgpuPgpuCompatibilityLimitCode_t) + ] + +## vGPU scheduler policy defines +NVML_VGPU_SCHEDULER_POLICY_UNKNOWN = 0 +NVML_VGPU_SCHEDULER_POLICY_BEST_EFFORT = 1 +NVML_VGPU_SCHEDULER_POLICY_EQUAL_SHARE = 2 +NVML_VGPU_SCHEDULER_POLICY_FIXED_SHARE = 3 + +## Supported vGPU scheduler policy count +NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT = 3 + +NVML_SCHEDULER_SW_MAX_LOG_ENTRIES = 200 + +NVML_VGPU_SCHEDULER_ARR_DEFAULT = 0 +NVML_VGPU_SCHEDULER_ARR_DISABLE = 1 +NVML_VGPU_SCHEDULER_ARR_ENABLE = 2 + +class c_nvmlVgpuSchedDataWithARR_t(_PrintableStructure): + _fields_ = [ + ('avgFactor', c_uint), + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedData_t(_PrintableStructure): + _fields_ = [ + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedulerParams_t(Union): + _fields_ = [ + ('vgpuSchedDataWithARR', c_nvmlVgpuSchedDataWithARR_t), + ('vgpuSchedData', c_nvmlVgpuSchedData_t), + ] + +class c_nvmlVgpuSchedulerLogEntry_t(_PrintableStructure): + _fields_ = [ + ('timestamp', c_ulonglong), + ('timeRunTotal', c_ulonglong), + ('timeRun', c_ulonglong), + ('swRunlistId', c_uint), + ('targetTimeSlice', c_ulonglong), + ('cumulativePreemptionTime', c_ulonglong), + ] + +class c_nvmlVgpuSchedulerLog_t(_PrintableStructure): + _fields_ = [ + ('engineId', c_uint), + ('schedulerPolicy', c_uint), + ('arrMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerParams_t), + ('entriesCount', c_uint), + ('logEntries', c_nvmlVgpuSchedulerLogEntry_t * NVML_SCHEDULER_SW_MAX_LOG_ENTRIES), + ] + +class c_nvmlVgpuSchedulerGetState_t(_PrintableStructure): + _fields_ = [ + ('schedulerPolicy', c_uint), + ('arrMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerParams_t), + ] + +class c_nvmlVgpuSchedSetDataWithARR_t(_PrintableStructure): + _fields_ = [ + ('avgFactor', c_uint), + ('frequency', c_uint), + ] + +class c_nvmlVgpuSchedSetData_t(_PrintableStructure): + _fields_ = [ + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedulerSetParams_t(Union): + _fields_ = [ + ('vgpuSchedDataWithARR', c_nvmlVgpuSchedSetDataWithARR_t), + ('vgpuSchedData', c_nvmlVgpuSchedSetData_t), + ] + +class c_nvmlVgpuSchedulerSetState_t(_PrintableStructure): + _fields_ = [ + ('schedulerPolicy', c_uint), + ('enableARRMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerSetParams_t), + ] + +class c_nvmlVgpuSchedulerCapabilities_t(_PrintableStructure): + _fields_ = [ + ('supportedSchedulers', c_uint * NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT), + ('maxTimeslice', c_uint), + ('minTimeslice', c_uint), + ('isArrModeSupported', c_uint), + ('maxFrequencyForARR', c_uint), + ('minFrequencyForARR', c_uint), + ('maxAvgFactorForARR', c_uint), + ('minAvgFactorForARR', c_uint), + ] + +class c_nvmlFBCStats_t(Structure): + _fields_ = [("sessionsCount", c_uint), + ("averageFPS", c_uint), + ("averageLatency", c_uint) + ] + +class c_nvmlFBCSession_t(_PrintableStructure): + _fields_ = [ + ('sessionId', c_uint), + ('pid', c_uint), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('displayOrdinal', c_uint), + ('sessionType', c_uint), + ('sessionFlags', c_uint), + ('hMaxResolution', c_uint), + ('vMaxResolution', c_uint), + ('hResolution', c_uint), + ('vResolution', c_uint), + ('averageFPS', c_uint), + ('averageLatency', c_uint), + ] + +NVML_DEVICE_MIG_DISABLE = 0x0 +NVML_DEVICE_MIG_ENABLE = 0x1 + +NVML_GPU_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_GPU_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_GPU_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_GPU_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_GPU_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_GPU_INSTANCE_PROFILE_8_SLICE = 0x5 
+NVML_GPU_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_GPU_INSTANCE_PROFILE_2_SLICE_REV1 = 0x8 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV2 = 0x9 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_GFX = 0xA +NVML_GPU_INSTANCE_PROFILE_2_SLICE_GFX = 0xB +NVML_GPU_INSTANCE_PROFILE_4_SLICE_GFX = 0xC +NVML_GPU_INSTANCE_PROFILE_COUNT = 0xD + +class c_nvmlGpuInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), + ("size", c_uint) + ] + +class c_nvmlGpuInstanceProfileInfo_t(Structure): + _fields_ = [("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + +nvmlGpuInstanceProfileInfo_v2 = 0x02000098 + +class c_nvmlGpuInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE) + ] + + def __init__(self): + super(c_nvmlGpuInstanceProfileInfo_v2_t, self).__init__(version=nvmlGpuInstanceProfileInfo_v2) + +class c_nvmlGpuInstanceInfo_t(Structure): + _fields_ = [("device", c_nvmlDevice_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", c_nvmlGpuInstancePlacement_t) + ] + +class struct_c_nvmlGpuInstance_t(Structure): + pass # opaque handle +c_nvmlGpuInstance_t = POINTER(struct_c_nvmlGpuInstance_t) + +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_COMPUTE_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_COMPUTE_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_COMPUTE_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_COMPUTE_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_COMPUTE_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_COMPUTE_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_COMPUTE_INSTANCE_PROFILE_COUNT = 0x8 + +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = 0x0 +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = 0x1 + +class c_nvmlComputeInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), + ("size", c_uint) + ] + +class c_nvmlComputeInstanceProfileInfo_t(Structure): + _fields_ = [("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint) + ] + +nvmlComputeInstanceProfileInfo_v2 = 0x02000088 + +class c_nvmlComputeInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE) + ] + + def __init__(self): + super(c_nvmlComputeInstanceProfileInfo_v2_t, self).__init__(version=nvmlComputeInstanceProfileInfo_v2) + +class c_nvmlComputeInstanceInfo_t(Structure): + _fields_ = [("device", c_nvmlDevice_t), + ("gpuInstance", c_nvmlGpuInstance_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", 
c_nvmlComputeInstancePlacement_t) + ] + +NVML_MAX_GPU_UTILIZATIONS = 8 +NVML_GPU_UTILIZATION_DOMAIN_GPU = 0 +NVML_GPU_UTILIZATION_DOMAIN_FB = 1 +NVML_GPU_UTILIZATION_DOMAIN_VID = 2 +NVML_GPU_UTILIZATION_DOMAIN_BUS = 3 +class c_nvmlGpuDynamicPstatesUtilization_t(Structure): + _fields_ = [("bIsPresent", c_uint, 1), + ("percentage", c_uint), + ("incThreshold", c_uint), + ("decThreshold", c_uint)] +class c_nvmlGpuDynamicPstatesInfo_t(Structure): + _fields_ = [("flags", c_uint), + ("utilization", c_nvmlGpuDynamicPstatesUtilization_t * NVML_MAX_GPU_UTILIZATIONS)] + +NVML_MAX_THERMAL_SENSORS_PER_GPU = 3 + +NVML_THERMAL_TARGET_NONE = 0 +NVML_THERMAL_TARGET_GPU = 1 +NVML_THERMAL_TARGET_MEMORY = 2 +NVML_THERMAL_TARGET_POWER_SUPPLY = 4 +NVML_THERMAL_TARGET_BOARD = 8 +NVML_THERMAL_TARGET_VCD_BOARD = 9 +NVML_THERMAL_TARGET_VCD_INLET = 10 +NVML_THERMAL_TARGET_VCD_OUTLET = 11 +NVML_THERMAL_TARGET_ALL = 15 +NVML_THERMAL_TARGET_UNKNOWN = -1 + +NVML_THERMAL_CONTROLLER_NONE = 0 +NVML_THERMAL_CONTROLLER_GPU_INTERNAL = 1 +NVML_THERMAL_CONTROLLER_ADM1032 = 2 +NVML_THERMAL_CONTROLLER_ADT7461 = 3 +NVML_THERMAL_CONTROLLER_MAX6649 = 4 +NVML_THERMAL_CONTROLLER_MAX1617 = 5 +NVML_THERMAL_CONTROLLER_LM99 = 6 +NVML_THERMAL_CONTROLLER_LM89 = 7 +NVML_THERMAL_CONTROLLER_LM64 = 8 +NVML_THERMAL_CONTROLLER_G781 = 9 +NVML_THERMAL_CONTROLLER_ADT7473 = 10 +NVML_THERMAL_CONTROLLER_SBMAX6649 = 11 +NVML_THERMAL_CONTROLLER_VBIOSEVT = 12 +NVML_THERMAL_CONTROLLER_OS = 13 +NVML_THERMAL_CONTROLLER_NVSYSCON_CANOAS = 14 +NVML_THERMAL_CONTROLLER_NVSYSCON_E551 = 15 +NVML_THERMAL_CONTROLLER_MAX6649R = 16 +NVML_THERMAL_CONTROLLER_ADT7473S = 17 +NVML_THERMAL_CONTROLLER_UNKNOWN = -1 + +class c_nvmlGpuThermalSensor_t(Structure): + _fields_ = [("controller", c_int), + ("defaultMinTemp", c_int), + ("defaultMaxTemp", c_int), + ("currentTemp", c_int), + ("target", c_int)] +class c_nvmlGpuThermalSettings_t(Structure): + _fields_ = [("count", c_uint), + ("sensor", c_nvmlGpuThermalSensor_t * NVML_MAX_THERMAL_SENSORS_PER_GPU)] + +_nvmlCoolerControl_t = c_uint +NVML_THERMAL_COOLER_SIGNAL_NONE = 0 +NVML_THERMAL_COOLER_SIGNAL_TOGGLE = 1 +NVML_THERMAL_COOLER_SIGNAL_VARIABLE = 2 +NVML_THERMAL_COOLER_SIGNAL_COUNT = 3 + +_nvmlCoolerTarget_t = c_uint +NVML_THERMAL_COOLER_TARGET_NONE = (1 << 0) +NVML_THERMAL_COOLER_TARGET_GPU = (1 << 1) +NVML_THERMAL_COOLER_TARGET_MEMORY = (1 << 2) +NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY = (1 << 3) +NVML_THERMAL_COOLER_TARGET_GPU_RELATED = (NVML_THERMAL_COOLER_TARGET_GPU | NVML_THERMAL_COOLER_TARGET_MEMORY | NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY) + +class c_nvmlCoolerInfo_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("index", c_uint), + ("coolerControlType", _nvmlCoolerControl_t), + ("coolerTarget", _nvmlCoolerTarget_t) + ] + +nvmlCoolerInfo_v1 = 0x1000010 + +def nvmlDeviceGetCoolerInfo(handle): + c_coolerInfo = c_nvmlCoolerInfo_t() + c_coolerInfo.version = nvmlCoolerInfo_v1 + c_coolerInfo.index = 0 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCoolerInfo") + ret = fn(handle, byref(c_coolerInfo)) + _nvmlCheckReturn(ret) + return [c_coolerInfo.coolerControlType, c_coolerInfo.coolerTarget] + +class struct_c_nvmlComputeInstance_t(Structure): + pass # opaque handle +c_nvmlComputeInstance_t = POINTER(struct_c_nvmlComputeInstance_t) + +class c_nvmlDeviceAttributes(Structure): + _fields_ = [("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("gpuInstanceSliceCount", 
c_uint), + ("computeInstanceSliceCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + +class c_nvmlRowRemapperHistogramValues(Structure): + _fields_ = [("max", c_uint), + ("high", c_uint), + ("partial", c_uint), + ("low", c_uint), + ("none", c_uint) + ] + +NVML_GPU_CERT_CHAIN_SIZE = 0x1000 +NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE = 0x1400 +NVML_CC_GPU_CEC_NONCE_SIZE = 0x20 +NVML_CC_GPU_ATTESTATION_REPORT_SIZE = 0x2000 +NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE = 0x1000 +NVML_CC_CEC_ATTESTATION_REPORT_NOT_PRESENT = 0 +NVML_CC_CEC_ATTESTATION_REPORT_PRESENT = 1 + +class c_nvmlConfComputeSystemState_t(Structure): + _fields_ = [('environment', c_uint), + ('ccFeature', c_uint), + ('devToolsMode', c_uint), + ] + +nvmlSystemConfComputeSettings_v1 = 0x1000014 + +class c_nvmlSystemConfComputeSettings_v1_t(Structure): + _fields_ = [('version', c_uint), + ('environment', c_uint), + ('ccFeature', c_uint), + ('devToolsMode', c_uint), + ('multiGpuMode', c_uint), + ] + def __init__(self): + super(c_nvmlSystemConfComputeSettings_v1_t, self).__init__(version=nvmlSystemConfComputeSettings_v1) + +class c_nvmlConfComputeSystemCaps_t(Structure): + _fields_ = [('cpuCaps', c_uint), + ('gpusCaps', c_uint), + ] + +class c_nvmlConfComputeMemSizeInfo_t(Structure): + _fields_ = [('protectedMemSizeKib', c_ulonglong), + ('unprotectedMemSizeKib', c_ulonglong), + ] + +class c_nvmlConfComputeGpuCertificate_t(Structure): + _fields_ = [('certChainSize', c_uint), + ('attestationCertChainSize', c_uint), + ('certChain', c_uint8 * NVML_GPU_CERT_CHAIN_SIZE), + ('attestationCertChain', c_uint8 * NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE), + ] + +class c_nvmlConfComputeGpuAttestationReport_t(Structure): + _fields_ = [('isCecAttestationReportPresent', c_uint), + ('attestationReportSize', c_uint), + ('cecAttestationReportSize', c_uint), + ('nonce', c_uint8 * NVML_CC_GPU_CEC_NONCE_SIZE), + ('attestationReport', c_uint8 * NVML_CC_GPU_ATTESTATION_REPORT_SIZE), + ('cecAttestationReport', c_uint8 * NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE), + ] + +class c_nvmlConfComputeSetKeyRotationThresholdInfo_t(Structure): + _fields_ = [('version', c_uint), + ('maxAttackerAdvantage', c_ulong), + ] +ConfComputeSetKeyRotationThresholdInfo_v1 = 0x1000010 + +class c_nvmlConfComputeGetKeyRotationThresholdInfo_t(Structure): + _fields_ = [('version', c_uint), + ('attackerAdvantage', c_ulong), + ] +ConfComputeGetKeyRotationThresholdInfo_v1 = 0x1000010 + + +## string/bytes conversion for ease of use +def convertStrBytes(func): + ''' + In python 3, strings are unicode instead of bytes, and need to be converted for ctypes + Args from caller: (1, 'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>) + Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF)> + ---- + Returned from function: b'returned string' + Returned to caller: 'returned string' + ''' + @wraps(func) + def wrapper(*args, **kwargs): + # encoding a str returns bytes in python 2 and 3 + args = [arg.encode() if isinstance(arg, str) else arg for arg in args] + res = func(*args, **kwargs) + # In python 2, str and bytes are the same + # In python 3, str is unicode and should be decoded. + # Ctypes handles most conversions, this only effects c_char and char arrays. 
+ if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + if sys.version_info >= (3,): + return wrapper + return func + +def throwOnVersionMismatch(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except NVMLError_FunctionNotFound: + raise NVMLLibraryMismatchError("Unversioned function called and the " + "pyNVML version does not match the NVML lib version. " + "Either use matching pyNVML and NVML lib versions or " + "use a versioned function such as " + func.__name__ + "_v2") + return wrapper + +## C function wrappers ## +def nvmlInitWithFlags(flags): + _LoadNvmlLibrary() + + # + # Initialize the library + # + fn = _nvmlGetFunctionPointer("nvmlInitWithFlags") + ret = fn(flags) + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + _nvmlLib_refcount += 1 + libLoadLock.release() + return None + +def nvmlInit(): + nvmlInitWithFlags(0) + return None + +def _LoadNvmlLibrary(): + ''' + Load the library if it isn't loaded already + ''' + global nvmlLib + + if (nvmlLib == None): + # lock to ensure only one caller loads the library + libLoadLock.acquire() + + try: + # ensure the library still isn't loaded + if (nvmlLib == None): + try: + if (sys.platform[:3] == "win"): + # cdecl calling convention + try: + # Check for nvml.dll in System32 first for DCH drivers + nvmlLib = CDLL(os.path.join(os.getenv("WINDIR", "C:/Windows"), "System32/nvml.dll")) + except OSError as ose: + # If nvml.dll is not found in System32, it should be in ProgramFiles + # load nvml.dll from %ProgramFiles%/NVIDIA Corporation/NVSMI/nvml.dll + nvmlLib = CDLL(os.path.join(os.getenv("ProgramFiles", "C:/Program Files"), "NVIDIA Corporation/NVSMI/nvml.dll")) + else: + # assume linux + nvmlLib = CDLL("libnvidia-ml.so.1") + except OSError as ose: + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + if (nvmlLib == None): + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + +def nvmlShutdown(): + # + # Leave the library loaded, but shutdown the interface + # + fn = _nvmlGetFunctionPointer("nvmlShutdown") + ret = fn() + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + if (0 < _nvmlLib_refcount): + _nvmlLib_refcount -= 1 + libLoadLock.release() + return None + +# Added in 2.285 +@convertStrBytes +def nvmlErrorString(result): + fn = _nvmlGetFunctionPointer("nvmlErrorString") + fn.restype = c_char_p # otherwise return is an int + ret = fn(result) + return ret + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetNVMLVersion(): + c_version = create_string_buffer(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetNVMLVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +def nvmlSystemGetCudaDriverVersion(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + +def nvmlSystemGetCudaDriverVersion_v2(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion_v2") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetProcessName(pid): + c_name = create_string_buffer(1024) + fn = _nvmlGetFunctionPointer("nvmlSystemGetProcessName") + 
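+    # the process name is copied into a fixed 1024-byte buffer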
ret = fn(c_uint(pid), c_name, c_uint(1024)) + _nvmlCheckReturn(ret) + return c_name.value + +@convertStrBytes +def nvmlSystemGetDriverVersion(): + c_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 2.285 +def nvmlSystemGetHicVersion(): + c_count = c_uint(0) + hics = None + fn = _nvmlGetFunctionPointer("nvmlSystemGetHicVersion") + + # get the count + ret = fn(byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # If there are no hics + if (c_count.value == 0): + return [] + + hic_array = c_nvmlHwbcEntry_t * c_count.value + hics = hic_array() + ret = fn(byref(c_count), hics) + _nvmlCheckReturn(ret) + return hics + +def nvmlSystemGetDriverBranch(): + c_branchInfo = c_nvmlSystemDriverBranchInfo_v1_t(0) + c_branchInfo.version = SystemDriverBranchInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverBranch") + ret = fn(byref(c_branchInfo), c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_branchInfo + +## Unit get functions +def nvmlUnitGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlUnitGetHandleByIndex(index): + c_index = c_uint(index) + unit = c_nvmlUnit_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetHandleByIndex") + ret = fn(c_index, byref(unit)) + _nvmlCheckReturn(ret) + return unit + +def nvmlUnitGetUnitInfo(unit): + c_info = c_nvmlUnitInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetUnitInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlUnitGetLedState(unit): + c_state = c_nvmlLedState_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetLedState") + ret = fn(unit, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + +def nvmlUnitGetPsuInfo(unit): + c_info = c_nvmlPSUInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetPsuInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlUnitGetTemperature(unit, type): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetTemperature") + ret = fn(unit, c_uint(type), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlUnitGetFanSpeedInfo(unit): + c_speeds = c_nvmlUnitFanSpeeds_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetFanSpeedInfo") + ret = fn(unit, byref(c_speeds)) + _nvmlCheckReturn(ret) + return c_speeds + +# added to API +def nvmlUnitGetDeviceCount(unit): + c_count = c_uint(0) + # query the unit to determine device count + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), None) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = NVML_SUCCESS + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlUnitGetDevices(unit): + c_count = c_uint(nvmlUnitGetDeviceCount(unit)) + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return c_devices + +## Device get functions +def nvmlDeviceGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def 
nvmlDeviceGetHandleByIndex(index): + c_index = c_uint(index) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByIndex_v2") + ret = fn(c_index, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleBySerial(serial): + c_serial = c_char_p(serial) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleBySerial") + ret = fn(c_serial, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleByUUID(uuid): + c_uuid = c_char_p(uuid) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByUUID") + ret = fn(c_uuid, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleByPciBusId(pciBusId): + c_busId = c_char_p(pciBusId) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByPciBusId_v2") + ret = fn(c_busId, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetName(handle): + c_name = create_string_buffer(NVML_DEVICE_NAME_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetName") + ret = fn(handle, c_name, c_uint(NVML_DEVICE_NAME_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_name.value + +class c_nvmlDevicePerfModes_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('str', c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + +nvmlDevicePerfModes_v1 = 0x1000804 + +@convertStrBytes +def nvmlDeviceGetPerformanceModes(handle): + perfModes = c_nvmlDevicePerfModes_v1_t() + perfModes.version = nvmlDevicePerfModes_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceModes") + ret = fn(handle, byref(perfModes)) + _nvmlCheckReturn(ret) + return perfModes.str + +class c_nvmlDeviceCurrentClockFreqs_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('str', c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + +nvmlDeviceCurrentClockFreqs_v1 = 0x1000804 + +@convertStrBytes +def nvmlDeviceGetCurrentClockFreqs(handle): + currentClockFreqs = c_nvmlDeviceCurrentClockFreqs_v1_t() + currentClockFreqs.version = nvmlDeviceCurrentClockFreqs_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClockFreqs") + ret = fn(handle, byref(currentClockFreqs)) + _nvmlCheckReturn(ret) + return currentClockFreqs.str + +def nvmlDeviceGetBoardId(handle): + c_id = c_uint(); + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardId") + ret = fn(handle, byref(c_id)) + _nvmlCheckReturn(ret) + return c_id.value + +def nvmlDeviceGetMultiGpuBoard(handle): + c_multiGpu = c_uint(); + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMultiGpuBoard") + ret = fn(handle, byref(c_multiGpu)) + _nvmlCheckReturn(ret) + return c_multiGpu.value + +def nvmlDeviceGetBrand(handle): + c_type = _nvmlBrandType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBrand") + ret = fn(handle, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + +def nvmlDeviceGetC2cModeInfoV1(handle): + c_info = c_nvmlC2cModeInfo_v1_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetC2cModeInfoV") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlDeviceGetC2cModeInfoV(handle): + return nvmlDeviceGetC2cModeInfoV1(handle) + +@convertStrBytes +def nvmlDeviceGetBoardPartNumber(handle): + c_part_number = create_string_buffer(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardPartNumber") + ret = fn(handle, c_part_number, c_uint(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return 
c_part_number.value + +@convertStrBytes +def nvmlDeviceGetSerial(handle): + c_serial = create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSerial") + ret = fn(handle, c_serial, c_uint(NVML_DEVICE_SERIAL_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_serial.value + +def nvmlDeviceGetModuleId(handle, moduleId=c_uint()): + isReference = type(moduleId) is not c_uint + moduleIdRef = moduleId if isReference else byref(moduleId) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetModuleId") + ret = fn(handle, moduleIdRef) + if isReference: + return ret + else: + _nvmlCheckReturn(ret) + return moduleId.value + +def nvmlDeviceGetMemoryAffinity(handle, nodeSetSize, scope): + affinity_array = c_ulonglong * nodeSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryAffinity") + ret = fn(handle, nodeSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceGetCpuAffinityWithinScope(handle, cpuSetSize, scope): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinityWithinScope") + ret = fn(handle, cpuSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceGetCpuAffinity(handle, cpuSetSize): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinity") + ret = fn(handle, cpuSetSize, byref(c_affinity)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceSetCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNumaNodeId(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumaNodeId") + node = c_int() + ret = fn(handle, byref(node)) + _nvmlCheckReturn(ret) + return node.value + +def nvmlDeviceGetMinorNumber(handle): + c_minor_number = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinorNumber") + ret = fn(handle, byref(c_minor_number)) + _nvmlCheckReturn(ret) + return c_minor_number.value + +@convertStrBytes +def nvmlDeviceGetUUID(handle): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUUID") + ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlDeviceGetInforomVersion(handle, infoRomObject): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomVersion") + ret = fn(handle, _nvmlInforomObject_t(infoRomObject), + c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 4.304 +@convertStrBytes +def nvmlDeviceGetInforomImageVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomImageVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 4.304 +def nvmlDeviceGetInforomConfigurationChecksum(handle): + c_checksum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomConfigurationChecksum") + ret = fn(handle, 
byref(c_checksum)) + _nvmlCheckReturn(ret) + return c_checksum.value + +# Added in 4.304 +def nvmlDeviceValidateInforom(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetLastBBXFlushTime(handle): + c_timestamp = c_ulonglong() + c_durationUs = c_ulong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetLastBBXFlushTime") + ret = fn(handle, byref(c_timestamp), byref(c_durationUs)) + _nvmlCheckReturn(ret) + return [c_timestamp.value, c_durationUs.value] + +def nvmlDeviceGetDisplayMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceGetDisplayActive(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayActive") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetPersistenceMode(handle): + c_state = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPersistenceMode") + ret = fn(handle, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +def nvmlDeviceGetPciInfoExt(handle, c_info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfoExt") + ret = fn(handle, c_info) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetPciInfo_v3(handle): + c_info = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfo_v3") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlDeviceGetPciInfo(handle): + return nvmlDeviceGetPciInfo_v3(handle) + +def nvmlDeviceGetClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 2.285 +def nvmlDeviceGetMaxClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 4.304 +def nvmlDeviceGetApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +def nvmlDeviceGetMaxCustomerBoostClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxCustomerBoostClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +def nvmlDeviceGetClock(handle, type, id): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClock") + ret = fn(handle, _nvmlClockType_t(type), _nvmlClockId_t(id), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 5.319 +def nvmlDeviceGetDefaultApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 4.304 +def nvmlDeviceGetSupportedMemoryClocks(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedMemoryClocks") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no clocks + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + clocks_array = 
c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + +# Added in 4.304 +def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedGraphicsClocks") + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no clocks + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetFanSpeed(handle): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed") + ret = fn(handle, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetFanSpeed_v2(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed_v2") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +class c_nvmlFanSpeedInfo_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('fan', c_uint), + ('speed', c_uint), + ] + +nvmlFanSpeedInfo_v1 = 0x100000C + +def nvmlDeviceGetFanSpeedRPM(handle): + c_fanSpeed = c_nvmlFanSpeedInfo_t() + c_fanSpeed.fan = 0 + c_fanSpeed.version = nvmlFanSpeedInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeedRPM") + ret = fn(handle, byref(c_fanSpeed)) + _nvmlCheckReturn(ret) + return c_fanSpeed.speed + +def nvmlDeviceGetTargetFanSpeed(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTargetFanSpeed") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetNumFans(device): + c_numFans = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumFans") + ret = fn(device, byref(c_numFans)) + _nvmlCheckReturn(ret) + return c_numFans.value + +def nvmlDeviceSetDefaultFanSpeed_v2(handle, index): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultFanSpeed_v2"); + ret = fn(handle, index) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetMinMaxFanSpeed(handle, minSpeed=c_uint(), maxSpeed=c_uint()): + isReference = (type(minSpeed) is not c_uint) or (type(maxSpeed) is not c_uint) + minSpeedRef = minSpeed if isReference else byref(minSpeed) + maxSpeedRef = maxSpeed if isReference else byref(maxSpeed) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxFanSpeed") + ret = fn(handle, minSpeedRef, maxSpeedRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [minSpeed.value, maxSpeed.value] + +def nvmlDeviceGetFanControlPolicy_v2(handle, fan, fanControlPolicy=c_uint()): + isReference = type(fanControlPolicy) is not c_uint + fanControlPolicyRef = fanControlPolicy if isReference else byref(fanControlPolicy) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanControlPolicy_v2") + ret = fn(handle, fan, fanControlPolicyRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else fanControlPolicy.value + +def nvmlDeviceSetFanControlPolicy(handle, fan, fanControlPolicy): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanControlPolicy") + ret = 
fn(handle, fan, _nvmlFanControlPolicy_t(fanControlPolicy)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +class c_nvmlTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('sensorType', _nvmlTemperatureSensors_t), + ('temperature', c_int), + ] +nvmlTemperature_v1 = 0x100000C + +def nvmlDeviceGetTemperatureV1(handle, sensor): + c_temp = c_nvmlTemperature_v1_t() + c_temp.version = nvmlTemperature_v1 + c_temp.sensorType = _nvmlTemperatureSensors_t(sensor) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureV") + ret = fn(handle, byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.temperature + +def nvmlDeviceGetTemperatureV(handle, sensor, version=nvmlTemperature_v1): + if version == nvmlTemperature_v1: + return nvmlDeviceGetTemperatureV1(handle, sensor) + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + +# DEPRECATED use nvmlDeviceGetTemperatureV instead +def nvmlDeviceGetTemperature(handle, sensor): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperature") + ret = fn(handle, _nvmlTemperatureSensors_t(sensor), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlDeviceGetTemperatureThreshold(handle, threshold): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlDeviceSetTemperatureThreshold(handle, threshold, temp): + c_temp = c_uint() + c_temp.value = temp + fn = _nvmlGetFunctionPointer("nvmlDeviceSetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetMarginTemperature(handle): + c_marginTempInfo = c_nvmlMarginTemperature_v1_t() + c_marginTempInfo.version = nvmlMarginTemperature_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMarginTemperature") + ret = fn(handle, byref(c_marginTempInfo)) + _nvmlCheckReturn(ret) + return c_marginTempInfo.marginTemperature + +# DEPRECATED use nvmlDeviceGetPerformanceState +def nvmlDeviceGetPowerState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + +def nvmlDeviceGetPerformanceState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + +def nvmlDeviceGetPowerManagementMode(handle): + c_pcapMode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementMode") + ret = fn(handle, byref(c_pcapMode)) + _nvmlCheckReturn(ret) + return c_pcapMode.value + +def nvmlDeviceGetPowerManagementLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + +# Added in 4.304 +def nvmlDeviceGetPowerManagementLimitConstraints(handle): + c_minLimit = c_uint() + c_maxLimit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimitConstraints") + ret = fn(handle, byref(c_minLimit), byref(c_maxLimit)) + _nvmlCheckReturn(ret) + return [c_minLimit.value, c_maxLimit.value] + +# Added in 4.304 +def nvmlDeviceGetPowerManagementDefaultLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementDefaultLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) 
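+    # NVML reports power management limits in milliwatts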
+ return c_limit.value + + +# Added in 331 +def nvmlDeviceGetEnforcedPowerLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEnforcedPowerLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + +def nvmlDeviceGetPowerUsage(handle): + c_watts = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerUsage") + ret = fn(handle, byref(c_watts)) + _nvmlCheckReturn(ret) + return c_watts.value + +def nvmlDeviceGetTotalEnergyConsumption(handle): + c_millijoules = c_uint64() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEnergyConsumption") + ret = fn(handle, byref(c_millijoules)) + _nvmlCheckReturn(ret) + return c_millijoules.value + +# Added in 4.304 +def nvmlDeviceGetGpuOperationMode(handle): + c_currState = _nvmlGpuOperationMode_t() + c_pendingState = _nvmlGpuOperationMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuOperationMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + +# Added in 4.304 +def nvmlDeviceGetCurrentGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[0] + +# Added in 4.304 +def nvmlDeviceGetPendingGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[1] + +def nvmlDeviceGetMemoryInfo(handle, version=None): + if not version: + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo") + else: + c_memory = c_nvmlMemory_v2_t() + c_memory.version = version + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo_v2") + ret = fn(handle, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + +def nvmlDeviceGetBAR1MemoryInfo(handle): + c_bar1_memory = c_nvmlBAR1Memory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBAR1MemoryInfo") + ret = fn(handle, byref(c_bar1_memory)) + _nvmlCheckReturn(ret) + return c_bar1_memory + +def nvmlDeviceGetComputeMode(handle): + c_mode = _nvmlComputeMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceGetCudaComputeCapability(handle): + c_major = c_int() + c_minor = c_int() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCudaComputeCapability") + ret = fn(handle, byref(c_major), byref(c_minor)) + _nvmlCheckReturn(ret) + return (c_major.value, c_minor.value) + +def nvmlDeviceGetEccMode(handle): + c_currState = _nvmlEnableState_t() + c_pendingState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEccMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + +# added to API +def nvmlDeviceGetCurrentEccMode(handle): + return nvmlDeviceGetEccMode(handle)[0] + +# added to API +def nvmlDeviceGetPendingEccMode(handle): + return nvmlDeviceGetEccMode(handle)[1] + +def nvmlDeviceGetDefaultEccMode(handle): + c_defaultState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultEccMode") + ret = fn(handle, byref(c_defaultState)) + _nvmlCheckReturn(ret) + return [c_defaultState.value] + +def nvmlDeviceGetTotalEccErrors(handle, errorType, counterType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEccErrors") + ret = fn(handle, _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +# This is deprecated, instead use nvmlDeviceGetMemoryErrorCounter +def 
nvmlDeviceGetDetailedEccErrors(handle, errorType, counterType): + c_counts = c_nvmlEccErrorCounts_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDetailedEccErrors") + ret = fn(handle, _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), byref(c_counts)) + _nvmlCheckReturn(ret) + return c_counts + +# Added in 4.304 +def nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryErrorCounter") + ret = fn(handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + _nvmlMemoryLocation_t(locationType), + byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetUtilizationRates(handle): + c_util = c_nvmlUtilization_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUtilizationRates") + ret = fn(handle, byref(c_util)) + _nvmlCheckReturn(ret) + return c_util + +def nvmlDeviceGetEncoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetDecoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDecoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetJpgUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetJpgUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetOfaUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetOfaUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetPcieReplayCounter(handle): + c_replay = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieReplayCounter") + ret = fn(handle, byref(c_replay)) + _nvmlCheckReturn(ret) + return c_replay.value + +def nvmlDeviceGetDriverModel(handle): + c_currModel = _nvmlDriverModel_t() + c_pendingModel = _nvmlDriverModel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDriverModel") + ret = fn(handle, byref(c_currModel), byref(c_pendingModel)) + _nvmlCheckReturn(ret) + return [c_currModel.value, c_pendingModel.value] + +# added to API +def nvmlDeviceGetCurrentDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[0] + +# added to API +def nvmlDeviceGetPendingDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[1] + +# Added in 2.285 +@convertStrBytes +def nvmlDeviceGetVbiosVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVbiosVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # 
typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +@throwOnVersionMismatch +def nvmlDeviceGetComputeRunningProcesses(handle): + return nvmlDeviceGetComputeRunningProcesses_v3(handle) + +def nvmlDeviceGetGraphicsRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for 
this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +@throwOnVersionMismatch +def nvmlDeviceGetGraphicsRunningProcesses(handle): + return nvmlDeviceGetGraphicsRunningProcesses_v3(handle) + +@throwOnVersionMismatch +def nvmlDeviceGetMPSComputeRunningProcesses(handle): + return nvmlDeviceGetMPSComputeRunningProcesses_v3(handle) + +def nvmlDeviceGetMPSComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetRunningProcessDetailList(handle, version, mode): + c_processDetailList = c_nvmlProcessDetailList_t() + c_processDetailList.version = version + c_processDetailList.mode = mode + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRunningProcessDetailList") + + # first call to get the size + ret = fn(handle, byref(c_processDetailList)) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + c_procs = c_nvmlProcessDetail_v1_t * c_processDetailList.numProcArrayEntries + c_processDetailList.procArray = cast((c_procs)(), POINTER(c_nvmlProcessDetail_v1_t)) + + # make the call again + ret = fn(handle, byref(c_processDetailList)) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_processDetailList.numProcArrayEntries): + # use an alternative struct for this object + obj = c_processDetailList.procArray[i] + if (obj.usedGpuMemory == 
NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + obj.usedGpuMemory = None + if (obj.usedGpuCcProtectedMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + obj.usedGpuCcProtectedMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetAutoBoostedClocksEnabled(handle): + c_isEnabled = _nvmlEnableState_t() + c_defaultIsEnabled = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAutoBoostedClocksEnabled") + ret = fn(handle, byref(c_isEnabled), byref(c_defaultIsEnabled)) + _nvmlCheckReturn(ret) + return [c_isEnabled.value, c_defaultIsEnabled.value] + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +## Set functions +def nvmlUnitSetLedState(unit, color): + fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState") + ret = fn(unit, _nvmlLedColor_t(color)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetPersistenceMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPersistenceMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetComputeMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetComputeMode") + ret = fn(handle, _nvmlComputeMode_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetEccMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetEccMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearEccErrorCounts(handle, counterType): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearEccErrorCounts") + ret = fn(handle, _nvmlEccCounterType_t(counterType)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetDriverModel(handle, model): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDriverModel") + ret = fn(handle, _nvmlDriverModel_t(model)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled)) + _nvmlCheckReturn(ret) + return None + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +def nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags)) + _nvmlCheckReturn(ret) + return None + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +def nvmlDeviceSetGpuLockedClocks(handle, minGpuClockMHz, maxGpuClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuLockedClocks") + ret = fn(handle, c_uint(minGpuClockMHz), c_uint(maxGpuClockMHz)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetGpuLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetGpuLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetMemoryLockedClocks(handle, minMemClockMHz, maxMemClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemoryLockedClocks") + ret = fn(handle, c_uint(minMemClockMHz), c_uint(maxMemClockMHz)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetMemoryLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetMemoryLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetClkMonStatus(handle, c_clkMonInfo=nvmlClkMonStatus_t()): + isReference = type(c_clkMonInfo) is not nvmlClkMonStatus_t + c_clkMonInfoRef = c_clkMonInfo if 
isReference else byref(c_clkMonInfo) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClkMonStatus") + ret = fn(handle, c_clkMonInfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_clkMonInfo + +# Added in 4.304 +def nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetApplicationsClocks") + ret = fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz)) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceResetApplicationsClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetApplicationsClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceSetPowerManagementLimit(handle, limit): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit") + ret = fn(handle, c_uint(limit)) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceSetGpuOperationMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuOperationMode") + ret = fn(handle, _nvmlGpuOperationMode_t(mode)) + _nvmlCheckReturn(ret) + return None + +# Added in 2.285 +def nvmlEventSetCreate(): + fn = _nvmlGetFunctionPointer("nvmlEventSetCreate") + eventSet = c_nvmlEventSet_t() + ret = fn(byref(eventSet)) + _nvmlCheckReturn(ret) + return eventSet + +# Added in 2.285 +def nvmlDeviceRegisterEvents(handle, eventTypes, eventSet): + fn = _nvmlGetFunctionPointer("nvmlDeviceRegisterEvents") + ret = fn(handle, c_ulonglong(eventTypes), eventSet) + _nvmlCheckReturn(ret) + return None + +# Added in 2.285 +def nvmlDeviceGetSupportedEventTypes(handle): + c_eventTypes = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedEventTypes") + ret = fn(handle, byref(c_eventTypes)) + _nvmlCheckReturn(ret) + return c_eventTypes.value + +# raises NVML_ERROR_TIMEOUT exception on timeout +def nvmlEventSetWait_v2(eventSet, timeoutms): + fn = _nvmlGetFunctionPointer("nvmlEventSetWait_v2") + data = c_nvmlEventData_t() + ret = fn(eventSet, byref(data), c_uint(timeoutms)) + _nvmlCheckReturn(ret) + return data + +def nvmlEventSetWait(eventSet, timeoutms): + return nvmlEventSetWait_v2(eventSet, timeoutms) + +# Added in 2.285 +def nvmlEventSetFree(eventSet): + fn = _nvmlGetFunctionPointer("nvmlEventSetFree") + ret = fn(eventSet) + _nvmlCheckReturn(ret) + return None + +# Added in 3.295 +def nvmlDeviceOnSameBoard(handle1, handle2): + fn = _nvmlGetFunctionPointer("nvmlDeviceOnSameBoard") + onSameBoard = c_int() + ret = fn(handle1, handle2, byref(onSameBoard)) + _nvmlCheckReturn(ret) + return (onSameBoard.value != 0) + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + +def 
nvmlDeviceGetGpuMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 4.304 +def nvmlDeviceGetSupportedClocksThrottleReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +def nvmlDeviceGetSupportedClocksEventReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +# Added in 4.304 +def nvmlDeviceGetCurrentClocksThrottleReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +def nvmlDeviceGetCurrentClocksEventReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +# Added in 5.319 +def nvmlDeviceGetIndex(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIndex") + c_index = c_uint() + ret = fn(handle, byref(c_index)) + _nvmlCheckReturn(ret) + return c_index.value + +# Added in 5.319 +def nvmlDeviceGetAccountingMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceSetAccountingMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAccountingMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearAccountingPids(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearAccountingPids") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetAccountingStats(handle, pid): + stats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingStats") + ret = fn(handle, c_uint(pid), byref(stats)) + _nvmlCheckReturn(ret) + if (stats.maxMemoryUsage == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + stats.maxMemoryUsage = None + return stats + +def nvmlDeviceGetAccountingPids(handle): + count = c_uint(nvmlDeviceGetAccountingBufferSize(handle)) + pids = (c_uint * count.value)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingPids") + ret = fn(handle, byref(count), pids) + _nvmlCheckReturn(ret) + return list(map(int, pids[0:count.value])) + +def nvmlDeviceGetAccountingBufferSize(handle): + bufferSize = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingBufferSize") + ret = fn(handle, byref(bufferSize)) + _nvmlCheckReturn(ret) + return int(bufferSize.value) + +def nvmlDeviceGetRetiredPages(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = 
c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + ret = fn(device, c_source, byref(c_count), c_pages) + _nvmlCheckReturn(ret) + return list(map(int, c_pages[0:c_count.value])) + +def nvmlDeviceGetRetiredPages_v2(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages_v2") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + times_array = c_ulonglong * c_count.value + c_times = times_array() + ret = fn(device, c_source, byref(c_count), c_pages, c_times) + _nvmlCheckReturn(ret) + return [ { 'address': int(c_pages[i]), 'timestamp': int(c_times[i]) } for i in range(c_count.value) ]; + +def nvmlDeviceGetRetiredPagesPendingStatus(device): + c_pending = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPagesPendingStatus") + ret = fn(device, byref(c_pending)) + _nvmlCheckReturn(ret) + return int(c_pending.value) + +def nvmlDeviceGetAPIRestriction(device, apiType): + c_permission = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAPIRestriction") + ret = fn(device, _nvmlRestrictedAPI_t(apiType), byref(c_permission)) + _nvmlCheckReturn(ret) + return int(c_permission.value) + +def nvmlDeviceSetAPIRestriction(handle, apiType, isRestricted): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAPIRestriction") + ret = fn(handle, _nvmlRestrictedAPI_t(apiType), _nvmlEnableState_t(isRestricted)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetBridgeChipInfo(handle): + bridgeHierarchy = c_nvmlBridgeChipHierarchy_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBridgeChipInfo") + ret = fn(handle, byref(bridgeHierarchy)) + _nvmlCheckReturn(ret) + return bridgeHierarchy + +def nvmlDeviceGetSamples(device, sampling_type, timeStamp): + c_sampling_type = _nvmlSamplingType_t(sampling_type) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_count = c_uint(0) + c_sample_value_type = _nvmlValueType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples") + + ## First Call gets the size + ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), byref(c_sample_count), None) + + # Stop if this fails + if (ret != NVML_SUCCESS): + raise NVMLError(ret) + + sampleArray = c_sample_count.value * c_nvmlSample_t + c_samples = sampleArray() + ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), byref(c_sample_count), c_samples) + _nvmlCheckReturn(ret) + return (c_sample_value_type.value, c_samples[0:c_sample_count.value]) + +def nvmlDeviceGetViolationStatus(device, perfPolicyType): + c_perfPolicy_type = _nvmlPerfPolicyType_t(perfPolicyType) + c_violTime = c_nvmlViolationTime_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus") + + ## Invoke the method to get violation time + ret = fn(device, c_perfPolicy_type, byref(c_violTime)) + _nvmlCheckReturn(ret) + return c_violTime + +def nvmlDeviceGetPcieThroughput(device, counter): + c_util = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieThroughput") + ret = fn(device, _nvmlPcieUtilCounter_t(counter), byref(c_util)) + 
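+    # NVML reports PCIe throughput in KB/s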
_nvmlCheckReturn(ret) + return c_util.value + +def nvmlSystemGetTopologyGpuSet(cpuNumber): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlSystemGetTopologyGpuSet") + + # First call will get the size + ret = fn(cpuNumber, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(cpuNumber, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0:c_count.value]) + +def nvmlDeviceGetTopologyNearestGpus(device, level): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyNearestGpus") + + # First call will get the size + ret = fn(device, level, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(device, level, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0:c_count.value]) + +def nvmlDeviceGetTopologyCommonAncestor(device1, device2): + c_level = _nvmlGpuTopologyLevel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyCommonAncestor") + ret = fn(device1, device2, byref(c_level)) + _nvmlCheckReturn(ret) + return c_level.value + +def nvmlDeviceGetNvLinkUtilizationCounter(device, link, counter): + c_rxcounter = c_ulonglong() + c_txcounter = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationCounter") + ret = fn(device, link, counter, byref(c_rxcounter), byref(c_txcounter)) + _nvmlCheckReturn(ret) + return (c_rxcounter.value, c_txcounter.value) + +def nvmlDeviceFreezeNvLinkUtilizationCounter(device, link, counter, freeze): + fn = _nvmlGetFunctionPointer("nvmlDeviceFreezeNvLinkUtilizationCounter") + ret = fn(device, link, counter, freeze) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetNvLinkUtilizationCounter(device, link, counter): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkUtilizationCounter") + ret = fn(device, link, counter) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetNvLinkUtilizationControl(device, link, counter, control, reset): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(control), reset) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNvLinkUtilizationControl(device, link, counter): + c_control = nvmlNvLinkUtilizationControl_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(c_control)) + _nvmlCheckReturn(ret) + return c_control + +def nvmlDeviceGetNvLinkCapability(device, link, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkCapability") + ret = fn(device, link, capability, byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceGetNvLinkErrorCounter(device, link, counter): + c_result = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkErrorCounter") + ret = fn(device, link, counter, byref(c_result)) + _nvmlCheckReturn(ret) + return c_result.value + +def nvmlDeviceResetNvLinkErrorCounters(device, link): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkErrorCounters") + ret = fn(device, link) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNvLinkRemotePciInfo(device, link): + c_pci = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemotePciInfo_v2") + ret = fn(device, link, byref(c_pci)) + _nvmlCheckReturn(ret) + 
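+    # c_pci describes the device at the remote end of this NVLink (the unversioned wrapper binds to the _v2 NVML entry point)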
return c_pci + +def nvmlDeviceGetNvLinkRemoteDeviceType(handle, link): + c_type = _nvmlNvLinkDeviceType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemoteDeviceType") + ret = fn(handle, link, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + +def nvmlDeviceGetNvLinkState(device, link): + c_isActive = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkState") + ret = fn(device, link, byref(c_isActive)) + _nvmlCheckReturn(ret) + return c_isActive.value + +def nvmlDeviceGetNvLinkVersion(device, link): + c_version = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkVersion") + ret = fn(device, link, byref(c_version)) + _nvmlCheckReturn(ret) + return c_version.value + +def nvmlDeviceModifyDrainState(pciInfo, newState): + fn = _nvmlGetFunctionPointer("nvmlDeviceModifyDrainState") + ret = fn(pointer(pciInfo), newState) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceQueryDrainState(pciInfo): + c_newState = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceQueryDrainState") + ret = fn(pointer(pciInfo), byref(c_newState)) + _nvmlCheckReturn(ret) + return c_newState.value + +def nvmlDeviceRemoveGpu(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceRemoveGpu") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceDiscoverGpus(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceDiscoverGpus") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + +def nvmlDeviceClearFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceClearFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + +def nvmlDeviceGetVirtualizationMode(handle): + c_virtualization_mode = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVirtualizationMode") + ret = fn(handle, byref(c_virtualization_mode)) + _nvmlCheckReturn(ret) + return c_virtualization_mode.value + +def nvmlDeviceSetVirtualizationMode(handle, virtualization_mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVirtualizationMode") + return fn(handle, virtualization_mode) + +def nvmlDeviceGetVgpuHeterogeneousMode(handle): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return c_vgpuHeterogeneousMode.mode + +def nvmlDeviceSetVgpuHeterogeneousMode(handle, heterogeneous_mode): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + c_vgpuHeterogeneousMode.mode = heterogeneous_mode + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def 
nvmlVgpuInstanceGetPlacementId(vgpuInstance): + c_placement = c_nvmlVgpuPlacementId_v1_t(0) + c_placement.version = VgpuPlacementId_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetPlacementId") + ret = fn(vgpuInstance, byref(c_placement)) + _nvmlCheckReturn(ret) + return c_placement.placementId + +def nvmlDeviceGetVgpuTypeSupportedPlacements(handle, vgpuTypeId, mode=0, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + c_vgpu_placements.mode = mode + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeSupportedPlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + +def nvmlDeviceGetVgpuTypeCreatablePlacements(handle, vgpuTypeId, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeCreatablePlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + +def nvmlGetVgpuDriverCapabilities(capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuDriverCapabilities") + ret = fn(_nvmlVgpuDriverCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceGetVgpuCapabilities(handle, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceSetVgpuCapabilities(handle, capability, state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetSupportedVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no supported vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise 
NVMLError(ret) + +def nvmlDeviceGetCreatableVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCreatableVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no supported vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuTypeGetGpuInstanceProfileId(vgpuTypeId): + c_profile_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGpuInstanceProfileId") + ret = fn(vgpuTypeId, byref(c_profile_id)) + _nvmlCheckReturn(ret) + return (c_profile_id.value) + +@convertStrBytes +def nvmlVgpuTypeGetClass(vgpuTypeId): + c_class = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetClass") + ret = fn(vgpuTypeId, c_class, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_class.value + +@convertStrBytes +def nvmlVgpuTypeGetName(vgpuTypeId): + c_name = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetName") + ret = fn(vgpuTypeId, c_name, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_name.value + +def nvmlVgpuTypeGetDeviceID(vgpuTypeId): + c_device_id = c_ulonglong(0) + c_subsystem_id = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetDeviceID") + ret = fn(vgpuTypeId, byref(c_device_id), byref(c_subsystem_id)) + _nvmlCheckReturn(ret) + return (c_device_id.value, c_subsystem_id.value) + +def nvmlVgpuTypeGetFramebufferSize(vgpuTypeId): + c_fb_size = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFramebufferSize") + ret = fn(vgpuTypeId, byref(c_fb_size)) + _nvmlCheckReturn(ret) + return c_fb_size.value + +def nvmlVgpuTypeGetNumDisplayHeads(vgpuTypeId): + c_num_heads = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetNumDisplayHeads") + ret = fn(vgpuTypeId, byref(c_num_heads)) + _nvmlCheckReturn(ret) + return c_num_heads.value + +def nvmlVgpuTypeGetResolution(vgpuTypeId): + c_xdim = c_uint(0) + c_ydim = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetResolution") + ret = fn(vgpuTypeId, 0, byref(c_xdim), byref(c_ydim)) + _nvmlCheckReturn(ret) + return (c_xdim.value, c_ydim.value) + +@convertStrBytes +def nvmlVgpuTypeGetLicense(vgpuTypeId): + c_license = create_string_buffer(NVML_GRID_LICENSE_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetLicense") + ret = fn(vgpuTypeId, c_license, c_buffer_size) + _nvmlCheckReturn(ret) + return c_license.value + +def nvmlVgpuTypeGetFrameRateLimit(vgpuTypeId): + c_frl_config = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFrameRateLimit") + ret = fn(vgpuTypeId, byref(c_frl_config)) + _nvmlCheckReturn(ret) + return c_frl_config.value + +def nvmlVgpuTypeGetGspHeapSize(vgpuTypeId): + c_gsp_heap = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGspHeapSize") + ret = fn(vgpuTypeId, byref(c_gsp_heap)) + _nvmlCheckReturn(ret) + return c_gsp_heap.value + +def nvmlVgpuTypeGetFbReservation(vgpuTypeId): + c_fb_reservation = 
c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFbReservation") + ret = fn(vgpuTypeId, byref(c_fb_reservation)) + _nvmlCheckReturn(ret) + return c_fb_reservation.value + +def nvmlVgpuInstanceGetRuntimeStateSize(vgpuInstance): + c_runtime_state = nvmlVgpuRuntimeState_v1_t() + c_runtime_state.version = VgpuRuntimeState_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetRuntimeStateSize") + ret = fn(vgpuInstance, byref(c_runtime_state)) + _nvmlCheckReturn(ret) + return c_runtime_state + +def nvmlVgpuTypeGetMaxInstances(handle, vgpuTypeId): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + return c_max_instances.value + +def nvmlVgpuTypeGetMaxInstancesPerVm(vgpuTypeId): + c_max_instances_per_vm = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstancesPerVm") + ret = fn(vgpuTypeId, byref(c_max_instances_per_vm)) + _nvmlCheckReturn(ret) + return c_max_instances_per_vm.value + +def nvmlVgpuTypeGetBAR1Info(vgpuTypeId): + c_bar1Info = c_nvmlVgpuTypeBar1Info_v1_t(0) + c_bar1Info.version = VgpuTypeBar1Info_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetBAR1Info") + ret = fn(vgpuTypeId, byref(c_bar1Info)) + _nvmlCheckReturn(ret) + return c_bar1Info + +def nvmlDeviceGetActiveVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetActiveVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_instance_array = _nvmlVgpuInstance_t * c_vgpu_count.value + c_vgpu_instances = vgpu_instance_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_instances) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_instances[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + +@convertStrBytes +def nvmlVgpuInstanceGetVmID(vgpuInstance): + c_vm_id = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + c_vm_id_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmID") + ret = fn(vgpuInstance, byref(c_vm_id), c_buffer_size, byref(c_vm_id_type)) + _nvmlCheckReturn(ret) + return (c_vm_id.value, c_vm_id_type.value) + +@convertStrBytes +def nvmlVgpuInstanceGetUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlVgpuInstanceGetMdevUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMdevUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlVgpuInstanceGetVmDriverVersion(vgpuInstance): + c_driver_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmDriverVersion") + ret = fn(vgpuInstance, byref(c_driver_version), c_buffer_size) + _nvmlCheckReturn(ret) + return c_driver_version.value + +def 
nvmlVgpuInstanceGetLicenseStatus(vgpuInstance): + c_license_status = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseStatus") + ret = fn(vgpuInstance, byref(c_license_status)) + _nvmlCheckReturn(ret) + return c_license_status.value + +def nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseInfo_v2") + c_license_info = c_nvmlVgpuLicenseInfo_t() + ret = fn(vgpuInstance, byref(c_license_info)) + _nvmlCheckReturn(ret) + return c_license_info + +def nvmlVgpuInstanceGetLicenseInfo(vgpuInstance): + return nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance) + +def nvmlVgpuInstanceGetFrameRateLimit(vgpuInstance): + c_frl = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFrameRateLimit") + ret = fn(vgpuInstance, byref(c_frl)) + _nvmlCheckReturn(ret) + return c_frl.value + +def nvmlVgpuInstanceGetEccMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEccMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlVgpuInstanceGetType(vgpuInstance): + c_vgpu_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetType") + ret = fn(vgpuInstance, byref(c_vgpu_type)) + _nvmlCheckReturn(ret) + return c_vgpu_type.value + +def nvmlVgpuInstanceGetEncoderCapacity(vgpuInstance): + c_encoder_capacity = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderCapacity") + ret = fn(vgpuInstance, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + +def nvmlVgpuInstanceSetEncoderCapacity(vgpuInstance, encoder_capacity): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceSetEncoderCapacity") + return fn(vgpuInstance, encoder_capacity) + +def nvmlVgpuInstanceGetFbUsage(vgpuInstance): + c_fb_usage = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFbUsage") + ret = fn(vgpuInstance, byref(c_fb_usage)) + _nvmlCheckReturn(ret) + return c_fb_usage.value + +def nvmlVgpuTypeGetCapabilities(vgpuTypeId, capability): + c_cap_result = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetCapabilities") + ret = fn(vgpuTypeId, _nvmlVgpuCapability_t(capability), byref(c_cap_result)) + _nvmlCheckReturn(ret) + return (c_cap_result.value) + +def nvmlVgpuInstanceGetGpuInstanceId(vgpuInstance): + c_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuInstanceId") + ret = fn(vgpuInstance, byref(c_id)) + _nvmlCheckReturn(ret) + return (c_id.value) + +@convertStrBytes +def nvmlVgpuInstanceGetGpuPciId(vgpuInstance): + c_vgpuPciId = create_string_buffer(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuPciId") + ret = fn(vgpuInstance, c_vgpuPciId, byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE))) + _nvmlCheckReturn(ret) + return c_vgpuPciId.value + +def nvmlDeviceGetVgpuUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_value_type = _nvmlValueType_t() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuUtilization") + ret = fn(handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuInstanceUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), 
c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetVgpuInstancesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuUtilInfo = c_nvmlVgpuInstancesUtilizationInfo_v1_t(0) + c_vgpuUtilInfo.version = VgpuInstancesUtilizationInfo_v1 + c_vgpuUtilInfo.sampleValType = _nvmlValueType_t() + c_vgpuUtilInfo.vgpuInstanceCount = c_uint(0) + c_vgpuUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuInstancesUtilizationInfo") + ret = fn(handle, byref(c_vgpuUtilInfo)) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpuUtilInfo.vgpuInstanceCount * c_nvmlVgpuInstanceUtilizationInfo_v1_t + c_samples = sampleArray() + c_vgpuUtilInfo.vgpuUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpuUtilInfo.vgpuInstanceCount] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetP2PStatus(device1, device2, p2pIndex): + c_p2pstatus = _nvmlGpuP2PStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetP2PStatus") + ret = fn(device1, device2,p2pIndex, byref(c_p2pstatus)) + _nvmlCheckReturn(ret) + return c_p2pstatus.value + +def nvmlDeviceGetGridLicensableFeatures_v4(handle): + c_get_grid_licensable_features = c_nvmlGridLicensableFeatures_v4_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGridLicensableFeatures_v4") + ret = fn(handle, byref(c_get_grid_licensable_features)) + _nvmlCheckReturn(ret) + + return (c_get_grid_licensable_features) + +def nvmlDeviceGetGridLicensableFeatures(handle): + return nvmlDeviceGetGridLicensableFeatures_v4(handle) + +def nvmlDeviceGetGspFirmwareVersion(handle, version=None): + isUserDefined = version is not None + if not isUserDefined: + version = (c_char * NVML_GSP_FIRMWARE_VERSION_BUF_SIZE)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareVersion") + ret = fn(handle, version) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else version.value + +def nvmlDeviceGetGspFirmwareMode(handle, isEnabled=c_uint(), defaultMode=c_uint()): + isReference = type(isEnabled) is not c_uint + isEnabledRef = isEnabled if isReference else byref(isEnabled) + defaultModeRef = defaultMode if isReference else byref(defaultMode) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareMode") + ret = fn(handle, isEnabledRef, defaultModeRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [isEnabled.value, defaultMode.value] + +def nvmlDeviceGetEncoderCapacity(handle, encoderQueryType): + c_encoder_capacity = c_ulonglong(0) + c_encoderQuery_type = _nvmlEncoderQueryType_t(encoderQueryType) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderCapacity") + ret = fn(handle, c_encoderQuery_type, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + +def nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessUtilization") + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuProcessUtilizationSample_t + 
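+            # ctypes idiom: multiplying a count by a structure type yields an array type;
+            # instantiating it below allocates the output buffer for the second call.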
c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetVgpuProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuProcUtilInfo = c_nvmlVgpuProcessesUtilizationInfo_v1_t(0) + c_vgpuProcUtilInfo.version = VgpuProcessesUtilizationInfo_v1 + c_vgpuProcUtilInfo.vgpuProcessCount = c_uint(0) + c_vgpuProcUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessesUtilizationInfo") + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpuProcUtilInfo.vgpuProcessCount * c_nvmlVgpuProcessUtilizationInfo_v1_t + c_samples = sampleArray() + c_vgpuProcUtilInfo.vgpuProcUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpuProcUtilInfo.vgpuProcessCount] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetEncoderStats(handle): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderStats") + ret = fn(handle, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + +def nvmlDeviceGetEncoderSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderSessions") + ret = fn(handle, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetFBCStats(handle): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCStats") + ret = fn(handle, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + +def nvmlDeviceGetFBCSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCSessions") + ret = fn(handle, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetEncoderStats(vgpuInstance): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderStats") + ret = fn(vgpuInstance, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + 
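+    # A single NVML call fills all three encoder statistics for this vGPU instance
+    # (session count, average FPS and average latency); failures raise below via _nvmlCheckReturn.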
_nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + +def nvmlVgpuInstanceGetEncoderSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetFBCStats(vgpuInstance): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCStats") + ret = fn(vgpuInstance, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + +def nvmlVgpuInstanceGetFBCSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetProcessUtilization(handle, timeStamp): + # first call to get the size + c_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessUtilization") + ret = fn(handle, None, byref(c_count), c_time_stamp) + + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_count.value * c_nvmlProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_samples, byref(c_count), c_time_stamp) + _nvmlCheckReturn(ret) + + return c_samples[0:c_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_processesUtilInfo = c_nvmlProcessesUtilizationInfo_v1_t(0) + c_processesUtilInfo.version = ProcessesUtilizationInfo_v1 + c_processesUtilInfo.processSamplesCount = c_uint(0) + c_processesUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessesUtilizationInfo") + ret = fn(handle, byref(c_processesUtilInfo)) + + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_processesUtilInfo.processSamplesCount * c_nvmlProcessUtilizationInfo_v1_t + c_samples = sampleArray() + c_processesUtilInfo.procUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_processesUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_processesUtilInfo.processSamplesCount] + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetMetadata(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMetadata") + c_vgpuMetadata = c_nvmlVgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the 
c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuMetadata + +def nvmlDeviceGetVgpuMetadata(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuMetadata") + c_vgpuPgpuMetadata = c_nvmlVgpuPgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuPgpuMetadata + +def nvmlGetVgpuCompatibility(vgpuMetadata, pgpuMetadata): + fn = _nvmlGetFunctionPointer("nvmlGetVgpuCompatibility") + c_vgpuPgpuCompatibility = c_nvmlVgpuPgpuCompatibility_t() + ret = fn(byref(vgpuMetadata), byref(pgpuMetadata), byref(c_vgpuPgpuCompatibility)) + _nvmlCheckReturn(ret) + return c_vgpuPgpuCompatibility + +@convertStrBytes +def nvmlDeviceGetPgpuMetadataString(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPgpuMetadataString") + c_pgpuMetadata = create_string_buffer(NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pgpuMetadata.value, c_bufferSize.value) + +def nvmlDeviceGetVgpuSchedulerLog(handle): + c_vgpu_sched_log = c_nvmlVgpuSchedulerLog_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerLog") + ret = fn(handle, byref(c_vgpu_sched_log)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_log + +def nvmlDeviceGetVgpuSchedulerState(handle): + c_vgpu_sched_state = c_nvmlVgpuSchedulerGetState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerState") + ret = fn(handle, byref(c_vgpu_sched_state)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_state + +def nvmlDeviceGetVgpuSchedulerCapabilities(handle): + c_vgpu_sched_caps = c_nvmlVgpuSchedulerCapabilities_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerCapabilities") + ret = fn(handle, byref(c_vgpu_sched_caps)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_caps + +def nvmlDeviceSetVgpuSchedulerState(handle, sched_state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuSchedulerState") + ret = fn(handle, byref(sched_state)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSetVgpuVersion(vgpuVersion): + fn = _nvmlGetFunctionPointer("nvmlSetVgpuVersion") + ret = fn(byref(vgpuVersion)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGetVgpuVersion(supported=None, current=None): + isUserDefined = (supported is not None) or (current is not None) + if not isUserDefined: + supported = c_nvmlVgpuVersion_t() + current = c_nvmlVgpuVersion_t() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuVersion") + ret = fn(byref(supported), byref(current)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else [(supported.minVersion, + supported.maxVersion), + (current.minVersion, + current.maxVersion)] + +def 
nvmlVgpuInstanceGetAccountingMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlVgpuInstanceGetAccountingPids(vgpuInstance): + c_pidCount = c_uint() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingPids") + ret = fn(vgpuInstance, byref(c_pidCount), None) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + sampleArray = c_pidCount.value * c_uint + c_pidArray = sampleArray() + ret = fn(vgpuInstance, byref(c_pidCount), byref(c_pidArray)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pidCount, c_pidArray) + +def nvmlVgpuInstanceGetAccountingStats(vgpuInstance, pid): + c_accountingStats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingStats") + ret = fn(vgpuInstance, pid, byref(c_accountingStats)) + _nvmlCheckReturn(ret) + return c_accountingStats + +def nvmlVgpuInstanceClearAccountingPids(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceClearAccountingPids") + ret = fn(vgpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGetExcludedDeviceCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlGetExcludedDeviceInfoByIndex(index): + c_index = c_uint(index) + info = c_nvmlExcludedDeviceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceInfoByIndex") + ret = fn(c_index, byref(info)) + _nvmlCheckReturn(ret) + return info + +def nvmlDeviceGetHostVgpuMode(handle): + c_host_vgpu_mode = _nvmlHostVgpuMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHostVgpuMode") + ret = fn(handle, byref(c_host_vgpu_mode)) + _nvmlCheckReturn(ret) + return c_host_vgpu_mode.value + +def nvmlDeviceSetMigMode(device, mode): + c_activationStatus = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMigMode") + ret = fn(device, mode, byref(c_activationStatus)) + _nvmlCheckReturn(ret) + return c_activationStatus.value + +def nvmlDeviceGetMigMode(device): + c_currentMode = c_uint() + c_pendingMode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigMode") + ret = fn(device, byref(c_currentMode), byref(c_pendingMode)) + _nvmlCheckReturn(ret) + return [c_currentMode.value, c_pendingMode.value] + +def nvmlDeviceGetGpuInstanceProfileInfo(device, profile, version=2): + if version == 2: + c_info = c_nvmlGpuInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlGpuInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +# Define function alias for the API exposed by NVML +nvmlDeviceGetGpuInstanceProfileInfoV = nvmlDeviceGetGpuInstanceProfileInfo + +def nvmlDeviceGetGpuInstanceRemainingCapacity(device, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceRemainingCapacity") + ret = fn(device, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetGpuInstancePossiblePlacements(device, profileId, placementsRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstancePossiblePlacements_v2") + ret = fn(device, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def 
nvmlDeviceCreateGpuInstance(device, profileId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstance") + ret = fn(device, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlDeviceCreateGpuInstanceWithPlacement(device, profileId, placement): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstanceWithPlacement") + ret = fn(device, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceDestroy(gpuInstance): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceDestroy") + ret = fn(gpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuInstances(device, profileId, gpuInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstances") + ret = fn(device, profileId, gpuInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuInstanceById(device, gpuInstanceId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceById") + ret = fn(device, gpuInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceGetInfo(gpuInstance): + c_info = c_nvmlGpuInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetInfo") + ret = fn(gpuInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlGpuInstanceGetComputeInstanceProfileInfo(device, profile, engProfile, version=2): + if version == 2: + c_info = c_nvmlComputeInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlComputeInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, engProfile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +# Define function alias for the API exposed by NVML +nvmlGpuInstanceGetComputeInstanceProfileInfoV = nvmlGpuInstanceGetComputeInstanceProfileInfo + +def nvmlGpuInstanceGetComputeInstanceRemainingCapacity(gpuInstance, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceRemainingCapacity") + ret = fn(gpuInstance, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlGpuInstanceGetComputeInstancePossiblePlacements(gpuInstance, profileId, placementsRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstancePossiblePlacements") + ret = fn(gpuInstance, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceCreateComputeInstance(gpuInstance, profileId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstance") + ret = fn(gpuInstance, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceCreateComputeInstanceWithPlacement(gpuInstance, profileId, placement): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstanceWithPlacement") + ret = fn(gpuInstance, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlComputeInstanceDestroy(computeInstance): + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceDestroy") + ret = fn(computeInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceGetComputeInstances(gpuInstance, 
profileId, computeInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstances") + ret = fn(gpuInstance, profileId, computeInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceGetComputeInstanceById(gpuInstance, computeInstanceId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceById") + ret = fn(gpuInstance, computeInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlComputeInstanceGetInfo_v2(computeInstance): + c_info = c_nvmlComputeInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceGetInfo_v2") + ret = fn(computeInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlComputeInstanceGetInfo(computeInstance): + return nvmlComputeInstanceGetInfo_v2(computeInstance) + +def nvmlDeviceIsMigDeviceHandle(device): + c_isMigDevice = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceIsMigDeviceHandle") + ret = fn(device, byref(c_isMigDevice)) + _nvmlCheckReturn(ret) + return c_isMigDevice + +def nvmlDeviceGetGpuInstanceId(device): + c_gpuInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceId") + ret = fn(device, byref(c_gpuInstanceId)) + _nvmlCheckReturn(ret) + return c_gpuInstanceId.value + +def nvmlDeviceGetComputeInstanceId(device): + c_computeInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeInstanceId") + ret = fn(device, byref(c_computeInstanceId)) + _nvmlCheckReturn(ret) + return c_computeInstanceId.value + +def nvmlDeviceGetMaxMigDeviceCount(device): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxMigDeviceCount") + ret = fn(device, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetMigDeviceHandleByIndex(device, index): + c_index = c_uint(index) + migDevice = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigDeviceHandleByIndex") + ret = fn(device, c_index, byref(migDevice)) + _nvmlCheckReturn(ret) + return migDevice + +def nvmlDeviceGetDeviceHandleFromMigDeviceHandle(migDevice): + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDeviceHandleFromMigDeviceHandle") + ret = fn(migDevice, byref(device)) + _nvmlCheckReturn(ret) + return device + +def nvmlDeviceGetAttributes_v2(device): + c_attrs = c_nvmlDeviceAttributes() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAttributes_v2") + ret = fn(device, byref(c_attrs)) + _nvmlCheckReturn(ret) + return c_attrs + +def nvmlDeviceGetAttributes(device): + return nvmlDeviceGetAttributes_v2(device) + +def nvmlDeviceGetRemappedRows(device): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRemappedRows") + c_corr = c_uint() + c_unc = c_uint() + c_bpending = c_uint() + c_bfailure = c_uint() + ret = fn(device, byref(c_corr), byref(c_unc), byref(c_bpending), byref(c_bfailure)) + _nvmlCheckReturn(ret) + return (c_corr.value, c_unc.value, c_bpending.value, c_bfailure.value) + +def nvmlDeviceGetRowRemapperHistogram(device): + c_vals = c_nvmlRowRemapperHistogramValues() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRowRemapperHistogram") + ret = fn(device, byref(c_vals)) + _nvmlCheckReturn(ret) + return c_vals + +def nvmlDeviceGetArchitecture(device): + arch = _nvmlDeviceArchitecture_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetArchitecture") + ret = fn(device, byref(arch)) + _nvmlCheckReturn(ret) + return arch.value + +def nvmlDeviceGetBusType(device): + c_busType = _nvmlBusType_t() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetBusType") + ret = fn(device, byref(c_busType)) + _nvmlCheckReturn(ret) + return c_busType.value + +def nvmlDeviceGetIrqNum(device): + c_irqNum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIrqNum") + ret = fn(device, byref(c_irqNum)) + _nvmlCheckReturn(ret) + return c_irqNum.value + +def nvmlDeviceGetNumGpuCores(device): + c_numCores = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumGpuCores") + ret = fn(device, byref(c_numCores)) + _nvmlCheckReturn(ret) + return c_numCores.value + +def nvmlDeviceGetPowerSource(device): + c_powerSource = _nvmlPowerSource_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerSource") + ret = fn(device, byref(c_powerSource)) + _nvmlCheckReturn(ret) + return c_powerSource.value + +def nvmlDeviceGetMemoryBusWidth(device): + c_memBusWidth = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryBusWidth") + ret = fn(device, byref(c_memBusWidth)) + _nvmlCheckReturn(ret) + return c_memBusWidth.value + +def nvmlDeviceGetPcieLinkMaxSpeed(device): + c_speed = _nvmlPcieLinkMaxSpeed_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieLinkMaxSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetAdaptiveClockInfoStatus(device): + c_adaptiveClockInfoStatus = _nvmlAdaptiveClockInfoStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAdaptiveClockInfoStatus") + ret = fn(device, byref(c_adaptiveClockInfoStatus)) + _nvmlCheckReturn(ret) + return c_adaptiveClockInfoStatus.value + +def nvmlDeviceGetPcieSpeed(device): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetDynamicPstatesInfo(device, c_dynamicpstatesinfo=c_nvmlGpuDynamicPstatesInfo_t()): + isReference = type(c_dynamicpstatesinfo) is not c_nvmlGpuDynamicPstatesInfo_t + dynamicpstatesinfoRef = c_dynamicpstatesinfo if isReference else byref(c_dynamicpstatesinfo) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDynamicPstatesInfo"); + ret = fn(device, dynamicpstatesinfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_dynamicpstatesinfo + +def nvmlDeviceSetFanSpeed_v2(handle, index, speed): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanSpeed_v2"); + ret = fn(handle, index, speed) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetThermalSettings(device, sensorindex, c_thermalsettings=c_nvmlGpuThermalSettings_t()): + isReference = type(c_thermalsettings) is not c_nvmlGpuThermalSettings_t + thermalsettingsRef = c_thermalsettings if isReference else byref(c_thermalsettings) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetThermalSettings"); + ret = fn(device, sensorindex, thermalsettingsRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_thermalsettings.sensor[:] + +def nvmlDeviceGetMinMaxClockOfPState(device, clockType, pstate, minClockMHz=c_uint(), maxClockMHz=c_uint()): + isReference = (type(minClockMHz) is not c_uint) or (type(maxClockMHz) is not c_uint) + minClockMHzRef = minClockMHz if isReference else byref(minClockMHz) + maxClockMHzRef = maxClockMHz if isReference else byref(maxClockMHz) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxClockOfPState"); + ret = fn(device, _nvmlClockType_t(clockType), _nvmlClockType_t(pstate), minClockMHzRef, maxClockMHzRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minClockMHz.value, maxClockMHz.value) + +class c_nvmlClockOffset_t(_PrintableStructure): + 
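+    # Versioned clock-offset query/set struct. Illustrative sketch only (the constants and the
+    # device handle are assumed to come from elsewhere in this module, e.g. NVML_CLOCK_GRAPHICS,
+    # NVML_PSTATE_0 and nvmlDeviceGetHandleByIndex); a caller would typically do something like:
+    #
+    #     info = c_nvmlClockOffset_t()
+    #     info.version = nvmlClockOffset_v1
+    #     info.type = NVML_CLOCK_GRAPHICS
+    #     info.pstate = NVML_PSTATE_0
+    #     nvmlDeviceGetClockOffsets(handle, byref(info))
+    #     print(info.clockOffsetMHz, info.minClockOffsetMHz, info.maxClockOffsetMHz)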
_fields_ = [ + ('version', c_uint), + ('type', _nvmlClockType_t), + ('pstate', _nvmlPstates_t), + ('clockOffsetMHz', c_int), + ('minClockOffsetMHz', c_int), + ('maxClockOffsetMHz', c_int), + ] + +nvmlClockOffset_v1 = 0x1000018 + +def nvmlDeviceGetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockOffsets"); + ret = fn(device, info) + return NVML_SUCCESS + +def nvmlDeviceSetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetClockOffsets"); + ret = fn(device, info) + return NVML_SUCCESS + +def nvmlDeviceGetSupportedPerformanceStates(device): + pstates = [] + c_count = c_uint(NVML_MAX_GPU_PERF_PSTATES) + c_size = sizeof(c_uint)*c_count.value + + # NOTE: use 'c_uint' to represent the size of the nvmlPstate_t enumeration. + pstates_array = _nvmlPstates_t * c_count.value + c_pstates = pstates_array() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedPerformanceStates") + ret = fn(device, c_pstates, c_size) + _nvmlCheckReturn(ret) + + for value in c_pstates: + if value != NVML_PSTATE_UNKNOWN: + pstates.append(value) + + return pstates + +def nvmlDeviceGetGpcClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + +def nvmlDeviceSetGpcClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpcClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpcClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + +def nvmlDeviceGetMemClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + +def nvmlDeviceSetMemClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetMemClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + +def nvmlSystemSetConfComputeGpusReadyState(state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeGpusReadyState") + ret = fn(c_state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetConfComputeGpusReadyState(): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeGpusReadyState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +def nvmlSystemGetConfComputeCapabilities(): + c_ccSysCaps = c_nvmlConfComputeSystemCaps_t() + fn = 
_nvmlGetFunctionPointer("nvmlSystemGetConfComputeCapabilities") + ret = fn(byref(c_ccSysCaps)) + _nvmlCheckReturn(ret) + return c_ccSysCaps + +def nvmlSystemGetConfComputeState(): + c_state = c_nvmlConfComputeSystemState_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + +def nvmlSystemGetConfComputeSettings(settings): + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeSettings") + return fn(settings) + +def nvmlDeviceSetConfComputeUnprotectedMemSize(device, c_ccMemSize): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetConfComputeUnprotectedMemSize") + ret = fn(device, c_ccMemSize) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetConfComputeMemSizeInfo(device): + c_ccMemSize = c_nvmlConfComputeMemSizeInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeMemSizeInfo") + ret = fn(device, byref(c_ccMemSize)) + _nvmlCheckReturn(ret) + return c_ccMemSize + +def nvmlDeviceGetConfComputeProtectedMemoryUsage(device): + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeProtectedMemoryUsage") + ret = fn(device, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + +def nvmlDeviceGetConfComputeGpuCertificate(device): + c_cert = c_nvmlConfComputeGpuCertificate_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuCertificate") + ret = fn(device, byref(c_cert)) + _nvmlCheckReturn(ret) + return c_cert + +def nvmlDeviceGetConfComputeGpuAttestationReport(device, c_nonce): + c_attestReport = c_nvmlConfComputeGpuAttestationReport_t() + c_nonce_arr = (c_uint8 * len(c_nonce))(*(c_nonce)) + setattr(c_attestReport, 'nonce', c_nonce_arr) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuAttestationReport") + ret = fn(device, byref(c_attestReport)) + _nvmlCheckReturn(ret) + return c_attestReport + +def nvmlSystemSetConfComputeKeyRotationThresholdInfo(max_atk_adv): + c_keyRotationThrInfo = c_nvmlConfComputeSetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeSetKeyRotationThresholdInfo_v1 + c_keyRotationThrInfo.maxAttackerAdvantage = max_atk_adv + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetConfComputeKeyRotationThresholdInfo(): + c_keyRotationThrInfo = c_nvmlConfComputeGetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeGetKeyRotationThresholdInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return c_keyRotationThrInfo + +## GPM ## +######### + +## Enums/defines + +#### GPM Metric Identifiers +NVML_GPM_METRIC_GRAPHICS_UTIL = 1 # Percentage of time any compute/graphics app was active on the GPU. 0.0 - 100.0 +NVML_GPM_METRIC_SM_UTIL = 2 # Percentage of SMs that were busy. 0.0 - 100.0 +NVML_GPM_METRIC_SM_OCCUPANCY = 3 # Percentage of warps that were active vs theoretical maximum. 0.0 - 100.0 +NVML_GPM_METRIC_INTEGER_UTIL = 4 # Percentage of time the GPU's SMs were doing integer operations. 0.0 - 100.0 +NVML_GPM_METRIC_ANY_TENSOR_UTIL = 5 # Percentage of time the GPU's SMs were doing ANY tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_DFMA_TENSOR_UTIL = 6 # Percentage of time the GPU's SMs were doing DFMA tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_HMMA_TENSOR_UTIL = 7 # Percentage of time the GPU's SMs were doing HMMA tensor operations. 
0.0 - 100.0 +NVML_GPM_METRIC_IMMA_TENSOR_UTIL = 9 # Percentage of time the GPU's SMs were doing IMMA tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_DRAM_BW_UTIL = 10 # Percentage of DRAM bw used vs theoretical maximum. 0.0 - 100.0 +NVML_GPM_METRIC_FP64_UTIL = 11 # Percentage of time the GPU's SMs were doing non-tensor FP64 math. 0.0 - 100.0 +NVML_GPM_METRIC_FP32_UTIL = 12 # Percentage of time the GPU's SMs were doing non-tensor FP32 math. 0.0 - 100.0 +NVML_GPM_METRIC_FP16_UTIL = 13 # Percentage of time the GPU's SMs were doing non-tensor FP16 math. 0.0 - 100.0 +NVML_GPM_METRIC_PCIE_TX_PER_SEC = 20 # PCIe traffic from this GPU in MiB/sec +NVML_GPM_METRIC_PCIE_RX_PER_SEC = 21 # PCIe traffic to this GPU in MiB/sec +NVML_GPM_METRIC_NVDEC_0_UTIL = 30 # Percent utilization of NVDEC 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_1_UTIL = 31 # Percent utilization of NVDEC 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_2_UTIL = 32 # Percent utilization of NVDEC 2. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_3_UTIL = 33 # Percent utilization of NVDEC 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_4_UTIL = 34 # Percent utilization of NVDEC 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_5_UTIL = 35 # Percent utilization of NVDEC 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_6_UTIL = 36 # Percent utilization of NVDEC 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_7_UTIL = 37 # Percent utilization of NVDEC 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_0_UTIL = 40 # Percent utilization of NVJPG 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_1_UTIL = 41 # Percent utilization of NVJPG 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_2_UTIL = 42 # Percent utilization of NVJPG 2. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_3_UTIL = 43 # Percent utilization of NVJPG 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_4_UTIL = 44 # Percent utilization of NVJPG 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_5_UTIL = 45 # Percent utilization of NVJPG 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_6_UTIL = 46 # Percent utilization of NVJPG 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_7_UTIL = 47 # Percent utilization of NVJPG 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_0_UTIL = 50 # Percent utilization of NVOFA 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_1_UTIL = 51 # Percent utilization of NVOFA 1. 
0.0 - 100.0 +NVML_GPM_METRIC_NVLINK_TOTAL_RX_PER_SEC = 60 # NvLink read bandwidth for all links in MiB/sec +NVML_GPM_METRIC_NVLINK_TOTAL_TX_PER_SEC = 61 # NvLink write bandwidth for all links in MiB/sec +NVML_GPM_METRIC_NVLINK_L0_RX_PER_SEC = 62 # NvLink read bandwidth for link 0 in MiB/sec +NVML_GPM_METRIC_NVLINK_L0_TX_PER_SEC = 63 # NvLink write bandwidth for link 0 in MiB/sec +NVML_GPM_METRIC_NVLINK_L1_RX_PER_SEC = 64 # NvLink read bandwidth for link 1 in MiB/sec +NVML_GPM_METRIC_NVLINK_L1_TX_PER_SEC = 65 # NvLink write bandwidth for link 1 in MiB/sec +NVML_GPM_METRIC_NVLINK_L2_RX_PER_SEC = 66 # NvLink read bandwidth for link 2 in MiB/sec +NVML_GPM_METRIC_NVLINK_L2_TX_PER_SEC = 67 # NvLink write bandwidth for link 2 in MiB/sec +NVML_GPM_METRIC_NVLINK_L3_RX_PER_SEC = 68 # NvLink read bandwidth for link 3 in MiB/sec +NVML_GPM_METRIC_NVLINK_L3_TX_PER_SEC = 69 # NvLink write bandwidth for link 3 in MiB/sec +NVML_GPM_METRIC_NVLINK_L4_RX_PER_SEC = 70 # NvLink read bandwidth for link 4 in MiB/sec +NVML_GPM_METRIC_NVLINK_L4_TX_PER_SEC = 71 # NvLink write bandwidth for link 4 in MiB/sec +NVML_GPM_METRIC_NVLINK_L5_RX_PER_SEC = 72 # NvLink read bandwidth for link 5 in MiB/sec +NVML_GPM_METRIC_NVLINK_L5_TX_PER_SEC = 73 # NvLink write bandwidth for link 5 in MiB/sec +NVML_GPM_METRIC_NVLINK_L6_RX_PER_SEC = 74 # NvLink read bandwidth for link 6 in MiB/sec +NVML_GPM_METRIC_NVLINK_L6_TX_PER_SEC = 75 # NvLink write bandwidth for link 6 in MiB/sec +NVML_GPM_METRIC_NVLINK_L7_RX_PER_SEC = 76 # NvLink read bandwidth for link 7 in MiB/sec +NVML_GPM_METRIC_NVLINK_L7_TX_PER_SEC = 77 # NvLink write bandwidth for link 7 in MiB/sec +NVML_GPM_METRIC_NVLINK_L8_RX_PER_SEC = 78 # NvLink read bandwidth for link 8 in MiB/sec +NVML_GPM_METRIC_NVLINK_L8_TX_PER_SEC = 79 # NvLink write bandwidth for link 8 in MiB/sec +NVML_GPM_METRIC_NVLINK_L9_RX_PER_SEC = 80 # NvLink read bandwidth for link 9 in MiB/sec +NVML_GPM_METRIC_NVLINK_L9_TX_PER_SEC = 81 # NvLink write bandwidth for link 9 in MiB/sec +NVML_GPM_METRIC_NVLINK_L10_RX_PER_SEC = 82 # NvLink read bandwidth for link 10 in MiB/sec +NVML_GPM_METRIC_NVLINK_L10_TX_PER_SEC = 83 # NvLink write bandwidth for link 10 in MiB/sec +NVML_GPM_METRIC_NVLINK_L11_RX_PER_SEC = 84 # NvLink read bandwidth for link 11 in MiB/sec +NVML_GPM_METRIC_NVLINK_L11_TX_PER_SEC = 85 # NvLink write bandwidth for link 11 in MiB/sec +NVML_GPM_METRIC_NVLINK_L12_RX_PER_SEC = 86 # NvLink read bandwidth for link 12 in MiB/sec +NVML_GPM_METRIC_NVLINK_L12_TX_PER_SEC = 87 # NvLink write bandwidth for link 12 in MiB/sec +NVML_GPM_METRIC_NVLINK_L13_RX_PER_SEC = 88 # NvLink read bandwidth for link 13 in MiB/sec +NVML_GPM_METRIC_NVLINK_L13_TX_PER_SEC = 89 # NvLink write bandwidth for link 13 in MiB/sec +NVML_GPM_METRIC_NVLINK_L14_RX_PER_SEC = 90 # NvLink read bandwidth for link 14 in MiB/sec +NVML_GPM_METRIC_NVLINK_L14_TX_PER_SEC = 91 # NvLink write bandwidth for link 14 in MiB/sec +NVML_GPM_METRIC_NVLINK_L15_RX_PER_SEC = 92 # NvLink read bandwidth for link 15 in MiB/sec +NVML_GPM_METRIC_NVLINK_L15_TX_PER_SEC = 93 # NvLink write bandwidth for link 15 in MiB/sec +NVML_GPM_METRIC_NVLINK_L16_RX_PER_SEC = 94 # NvLink read bandwidth for link 16 in MiB/sec +NVML_GPM_METRIC_NVLINK_L16_TX_PER_SEC = 95 # NvLink write bandwidth for link 16 in MiB/sec +NVML_GPM_METRIC_NVLINK_L17_RX_PER_SEC = 96 # NvLink read bandwidth for link 17 in MiB/sec +NVML_GPM_METRIC_NVLINK_L17_TX_PER_SEC = 97 # NvLink write bandwidth for link 17 in MiB/sec +NVML_GPM_METRIC_MAX = 98 + +## Structs + +class c_nvmlUnitInfo_t(_PrintableStructure): + 
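+    # Fixed-size character fields (96 bytes each) holding the unit's name, id,
+    # serial number and firmware version.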
_fields_ = [ + ('name', c_char * 96), + ('id', c_char * 96), + ('serial', c_char * 96), + ('firmwareVersion', c_char * 96), + ] + +class struct_c_nvmlGpmSample_t(Structure): + pass # opaque handle +c_nvmlGpmSample_t = POINTER(struct_c_nvmlGpmSample_t) + +class c_metricInfo_t(Structure): + _fields_ = [ + ("shortName", c_char_p), + ("longName", c_char_p), + ("unit", c_char_p), + ] + +class c_nvmlGpmMetric_t(_PrintableStructure): + _fields_ = [ + ('metricId', c_uint), + ('nvmlReturn', _nvmlReturn_t), + ('value', c_double), + ('metricInfo', c_metricInfo_t) + ] + +class c_nvmlGpmMetricsGet_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('numMetrics', c_uint), + ('sample1', c_nvmlGpmSample_t), + ('sample2', c_nvmlGpmSample_t), + ('metrics', c_nvmlGpmMetric_t * NVML_GPM_METRIC_MAX) + ] + +NVML_GPM_METRICS_GET_VERSION = 1 + +class c_nvmlGpmSupport_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('isSupportedDevice', c_uint), + ] + +NVML_GPM_SUPPORT_VERSION = 1 + +## Functions + +def nvmlGpmMetricsGet(metricsGet): + fn = _nvmlGetFunctionPointer("nvmlGpmMetricsGet") + ret = fn(byref(metricsGet)) + _nvmlCheckReturn(ret) + return metricsGet + +def nvmlGpmSampleFree(gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleFree") + ret = fn(gpmSample) + _nvmlCheckReturn(ret) + return + +def nvmlGpmSampleAlloc(): + gpmSample = c_nvmlGpmSample_t() + fn = _nvmlGetFunctionPointer("nvmlGpmSampleAlloc") + ret = fn(byref(gpmSample)) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmSampleGet(device, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleGet") + ret = fn(device, gpmSample) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmMigSampleGet(device, gpuInstanceId, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmMigSampleGet") + ret = fn(device, gpuInstanceId, gpmSample) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmQueryDeviceSupport(device): + gpmSupport = c_nvmlGpmSupport_t() + gpmSupport.version = NVML_GPM_SUPPORT_VERSION + fn = _nvmlGetFunctionPointer("nvmlGpmQueryDeviceSupport") + ret = fn(device, byref(gpmSupport)) + _nvmlCheckReturn(ret) + return gpmSupport + +def nvmlGpmSetStreamingEnabled(device, state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlGpmSetStreamingEnabled") + ret = fn(device, c_state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpmQueryIfStreamingEnabled(device): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpmQueryIfStreamingEnabled") + ret = fn(device, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +# Low Power Structure and Function + +NVML_NVLINK_POWER_STATE_HIGH_SPEED = 0x0 +NVML_NVLINK_POWER_STATE_LOW = 0x1 + +NVML_NVLINK_LOW_POWER_THRESHOLD_MIN = 0x1 +NVML_NVLINK_LOW_POWER_THRESHOLD_MAX = 0x1FFF +NVML_NVLINK_LOW_POWER_THRESHOLD_RESET = 0xFFFFFFFF +NVML_NVLINK_LOW_POWER_THRESHOLD_DEFAULT = NVML_NVLINK_LOW_POWER_THRESHOLD_RESET + +class c_nvmlNvLinkPowerThres_t(Structure): + _fields_ = [ + ("lowPwrThreshold", c_uint), + ] + +def nvmlDeviceSetNvLinkDeviceLowPowerThreshold(device, l1threshold): + c_info = c_nvmlNvLinkPowerThres_t() + c_info.lowPwrThreshold = l1threshold + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkDeviceLowPowerThreshold") + ret = fn(device, byref(c_info)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +NVML_GPU_FABRIC_UUID_LEN = 16 + +_nvmlGpuFabricState_t = c_uint +NVML_GPU_FABRIC_STATE_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_STATE_NOT_STARTED = 1 +NVML_GPU_FABRIC_STATE_IN_PROGRESS = 2 +NVML_GPU_FABRIC_STATE_COMPLETED = 3 + +class 
c_nvmlGpuFabricInfo_t(_PrintableStructure): + _fields_ = [ + ("clusterUuid", c_char * NVML_DEVICE_UUID_BUFFER_SIZE), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t) + ] + +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_DEGRADED_BW = 0 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_DEGRADED_BW = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_RECOVERY = 2 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_RECOVERY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_UNHEALTHY = 4 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_UNHEALTHY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ACCESS_TIMEOUT_RECOVERY = 6 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ACCESS_TIMEOUT_RECOVERY = 0x11 + +nvmlGpuFabricInfo_v2 = 0x02000024 + +class c_nvmlGpuFabricInfoV_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("clusterUuid", c_char * NVML_GPU_FABRIC_UUID_LEN), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t), + ("healthMask", c_uint32) + ] + + def __init__(self): + super(c_nvmlGpuFabricInfoV_t, self).__init__(version=nvmlGpuFabricInfo_v2) + +def nvmlDeviceGetGpuFabricInfo(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfo"); + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuFabricInfoV(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfoV"); + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +###################### +## Enums/defines +#### NVML GPU NVLINK BW MODE +NVML_GPU_NVLINK_BW_MODE_FULL = 0x0 +NVML_GPU_NVLINK_BW_MODE_OFF = 0x1 +NVML_GPU_NVLINK_BW_MODE_MIN = 0x2 +NVML_GPU_NVLINK_BW_MODE_HALF = 0x3 +NVML_GPU_NVLINK_BW_MODE_3QUARTER = 0x4 +NVML_GPU_NVLINK_BW_MODE_COUNT = 0x5 + +def nvmlSystemSetNvlinkBwMode(mode): + fn = _nvmlGetFunctionPointer("nvmlSystemSetNvlinkBwMode") + ret = fn(mode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetNvlinkBwMode(): + mode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetNvlinkBwMode") + ret = fn(byref(mode)) + _nvmlCheckReturn(ret) + return mode.value + +_nvmlPowerScopeType_t = c_uint +NVML_POWER_SCOPE_GPU = 0 +NVML_POWER_SCOPE_MODULE = 1 +NVML_POWER_SCOPE_MEMORY = 2 + +class c_nvmlPowerValue_v2_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('powerScope', _nvmlPowerScopeType_t), + ('powerValueMw', c_uint), + ] + _fmt_ = {'': "%d B"} + +nvmlPowerValue_v2 = 0x0200000C + +def nvmlDeviceSetPowerManagementLimit_v2(device, powerScope, powerLimit, version=nvmlPowerValue_v2): + c_powerScope = _nvmlPowerScopeType_t(powerScope) + c_powerValue = c_nvmlPowerValue_v2_t() + c_powerValue.version = c_uint(version) + c_powerValue.powerScope = c_powerScope + c_powerValue.powerValueMw = c_uint(powerLimit) + fn = 
_nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit_v2") + ret = fn(device, byref(c_powerValue)) + return NVML_SUCCESS + +class c_nvmlEccSramErrorStatus_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('aggregateUncParity', c_ulonglong), + ('aggregateUncSecDed', c_ulonglong), + ('aggregateCor', c_ulonglong), + ('volatileUncParity', c_ulonglong), + ('volatileUncSecDed', c_ulonglong), + ('volatileCor', c_ulonglong), + ('aggregateUncBucketL2', c_ulonglong), + ('aggregateUncBucketSm', c_ulonglong), + ('aggregateUncBucketPcie', c_ulonglong), + ('aggregateUncBucketMcu', c_ulonglong), + ('aggregateUncBucketOther', c_ulonglong), + ('bThresholdExceeded', c_uint) + ] + + def __init__(self): + super(c_nvmlEccSramErrorStatus_v1_t, self).__init__(version=nvmlEccSramErrorStatus_v1) + +nvmlEccSramErrorStatus_v1 = 0x1000068 +def nvmlDeviceGetSramEccErrorStatus(device, status): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSramEccErrorStatus") + ret = fn(device, status) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +NVML_DEV_CAP_EGM = (1 << 0) +nvmlDeviceCapabilities_v1 = 0x1000008 + +class c_nvmlDeviceCapabilities_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('capMask', c_uint), + ] + + def __init__(self): + super(c_nvmlDeviceCapabilities_v1_t, self).__init__(version=nvmlDeviceCapabilities_v1) + + +def nvmlDeviceGetCapabilities(device, caps): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCapabilities") + return fn(device, caps) + +class c_nvmlPlatformInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('ibGuid', c_char * 16), + ('rackGuid', c_char * 16), + ('chassisPhysicalSlotNumber', c_char), + ('computeSlotIndex', c_char), + ('nodeIndex', c_char), + ('peerType', c_char), + ('moduleId', c_char) + ] + + def __init__(self): + super(c_nvmlPlatformInfo_v1_t, self).__init__(version=nvmlPlatformInfo_v1) + +nvmlPlatformInfo_v1 = 0x100002c +def nvmlDeviceGetPlatformInfo(device, platformInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPlatformInfo") + ret = fn(device, platformInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +class c_nvmlMask255_t(_PrintableStructure): + _fields_ = [ + ('mask', c_uint * 8), + ] + +NVML_WORKLOAD_POWER_MAX_PROFILES = 255 +NVML_POWER_PROFILE_MAX_P = 0 +NVML_POWER_PROFILE_MAX_Q = 1 +NVML_POWER_PROFILE_COMPUTE = 2 +NVML_POWER_PROFILE_MEMORY_BOUND = 3 +NVML_POWER_PROFILE_NETWORK = 4 +NVML_POWER_PROFILE_BALANCED = 5 +NVML_POWER_PROFILE_LLM_INFERENCE = 6 +NVML_POWER_PROFILE_LLM_TRAINING = 7 +NVML_POWER_PROFILE_RBM = 8 +NVML_POWER_PROFILE_DCPCIE = 9 +NVML_POWER_PROFILE_HMMA_SPARSE = 10 +NVML_POWER_PROFILE_HMMA_DENSE = 11 +NVML_POWER_PROFILE_SYNC_BALANCED = 12 +NVML_POWER_PROFILE_HPC = 13 +NVML_POWER_PROFILE_MIG = 14 +NVML_POWER_PROFILE_MAX = 15 + +nvmlWorkloadPowerProfileInfo_v1 = 0x100002c +class c_nvmlWorkloadPowerProfileInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('profileId', c_uint), + ('priority', c_uint), + ('conflictingmask', c_nvmlMask255_t) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileInfo_v1_t, self).__init__(version=nvmlWorkloadPowerProfileInfo_v1) + +nvmlWorkloadPowerProfileProfilesInfo_v1 = 0x1002bf8 +class c_nvmlWorkloadPowerProfileProfilesInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('perfProfilesMask', c_nvmlMask255_t), + ('perfProfile', c_nvmlWorkloadPowerProfileInfo_v1_t * NVML_WORKLOAD_POWER_MAX_PROFILES) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileProfilesInfo_v1_t, 
self).__init__(version=nvmlWorkloadPowerProfileProfilesInfo_v1) + +nvmlWorkloadPowerProfileCurrentProfiles_v1 = 0x1000064 +class c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('perfProfilesMask', c_nvmlMask255_t), + ('requestedProfilesMask', c_nvmlMask255_t), + ('enforcedProfilesMask', c_nvmlMask255_t) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t, self).__init__(version=nvmlWorkloadPowerProfileCurrentProfiles_v1) + +nvmlWorkloadPowerProfileRequestedProfiles_v1 = 0x1000024 +class c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('requestedProfilesMask', c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t, self).__init__(version=nvmlWorkloadPowerProfileRequestedProfiles_v1) + +def nvmlDeviceWorkloadPowerProfileGetProfilesInfo(device, profilesInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetProfilesInfo") + ret = fn(device, profilesInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileGetCurrentProfiles(device, currentProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetCurrentProfiles") + ret = fn(device, currentProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileSetRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileSetRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileClearRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileClearRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetNvlinkSupportedBwModes(device, supportedBwModes): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkSupportedBwModes") + ret = fn(device, supportedBwModes) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetNvlinkBwMode(device, getBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkBwMode") + ret = fn(device, getBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceSetNvlinkBwMode(device, setBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvlinkBwMode") + ret = fn(device, setBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +nvmlDramEncryptionInfo_v1 = 0x01000008 + +class c_nvmlDramEncryptionInfo_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('encryptionState', _nvmlEnableState_t), + ] + + def __init__(self): + super(c_nvmlDramEncryptionInfo_t, self).__init__(version=nvmlDramEncryptionInfo_v1) + +def nvmlDeviceGetDramEncryptionMode(handle): + c_currState = c_nvmlDramEncryptionInfo_t() + c_pendingState = c_nvmlDramEncryptionInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDramEncryptionMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.encryptionState, c_pendingState.encryptionState] + +# added to API +def nvmlDeviceGetCurrentDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[0] + +# added to API +def nvmlDeviceGetPendingDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[1] + +def nvmlDeviceSetDramEncryptionMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDramEncryptionMode") + c_dramEncryptionMode = 
c_nvmlDramEncryptionInfo_t() + c_dramEncryptionMode.encryptionState = mode; + ret = fn(handle, byref(c_dramEncryptionMode)) + _nvmlCheckReturn(ret) + return None + +# Power Smoothing defines +NVML_POWER_SMOOTHING_MAX_NUM_PROFILES = 5 +NVML_POWER_SMOOTHING_ADMIN_OVERRIDE_NOT_SET = 0xFFFFFFFF +NVML_POWER_SMOOTHING_PROFILE_PARAM_PERCENT_TMP_FLOOR = 0 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_UP_RATE = 1 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_RATE = 2 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_HYSTERESIS = 3 + +nvmlPowerSmoothingState_v1=0x1000008 +class c_nvmlPowerSmoothingState_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('state', c_uint), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingState_v1_t, self).__init__(version=nvmlPowerSmoothingState_v1) + +nvmlPowerSmoothingProfile_v1=0x1000018 +class c_nvmlPowerSmoothingProfile_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('profileId', c_uint), + ('paramId', c_uint), + ('value', c_double), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingProfile_v1_t, self).__init__(version=nvmlPowerSmoothingProfile_v1) + +def nvmlDevicePowerSmoothingActivatePresetProfile(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingActivatePresetProfile") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + +def nvmlDevicePowerSmoothingUpdatePresetProfileParam(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingUpdatePresetProfileParam") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + +def nvmlDevicePowerSmoothingSetState(device, state): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingSetState") + ret = fn(device, state) + _nvmlCheckReturn(ret) + diff --git a/vllm/tracing.py b/vllm/tracing.py new file mode 100644 index 0000000..557ae40 --- /dev/null +++ b/vllm/tracing.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from collections.abc import Mapping +from typing import Optional + +from vllm.logger import init_logger +from vllm.utils import run_once + +TRACE_HEADERS = ["traceparent", "tracestate"] + +logger = init_logger(__name__) + +_is_otel_imported = False +otel_import_error_traceback: Optional[str] = None +try: + from opentelemetry.context.context import Context + from opentelemetry.sdk.environment_variables import ( + OTEL_EXPORTER_OTLP_TRACES_PROTOCOL) + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator) + _is_otel_imported = True +except ImportError: + # Capture and format traceback to provide detailed context for the import + # error. Only the string representation of the error is retained to avoid + # memory leaks. + # See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458 + import traceback + otel_import_error_traceback = traceback.format_exc() + + class Context: # type: ignore + pass + + class BaseSpanAttributes: # type: ignore + pass + + class SpanKind: # type: ignore + pass + + class Tracer: # type: ignore + pass + + +def is_otel_available() -> bool: + return _is_otel_imported + + +def init_tracer(instrumenting_module_name: str, + otlp_traces_endpoint: str) -> Optional[Tracer]: + if not is_otel_available(): + raise ValueError( + "OpenTelemetry is not available. Unable to initialize " + "a tracer. Ensure OpenTelemetry packages are installed. 
" + f"Original error:\n{otel_import_error_traceback}") + trace_provider = TracerProvider() + + span_exporter = get_span_exporter(otlp_traces_endpoint) + trace_provider.add_span_processor(BatchSpanProcessor(span_exporter)) + set_tracer_provider(trace_provider) + + tracer = trace_provider.get_tracer(instrumenting_module_name) + return tracer + + +def get_span_exporter(endpoint): + protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc") + if protocol == "grpc": + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter) + elif protocol == "http/protobuf": + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter) # type: ignore + else: + raise ValueError( + f"Unsupported OTLP protocol '{protocol}' is configured") + + return OTLPSpanExporter(endpoint=endpoint) + + +def extract_trace_context( + headers: Optional[Mapping[str, str]]) -> Optional[Context]: + if is_otel_available(): + headers = headers or {} + return TraceContextTextMapPropagator().extract(headers) + else: + return None + + +def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]: + + return {h: headers[h] for h in TRACE_HEADERS if h in headers} + + +class SpanAttributes: + # Attribute names copied from here to avoid version conflicts: + # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md + GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens" + GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens" + GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens" + GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p" + GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature" + GEN_AI_RESPONSE_MODEL = "gen_ai.response.model" + # Attribute names added until they are added to the semantic conventions: + GEN_AI_REQUEST_ID = "gen_ai.request.id" + GEN_AI_REQUEST_N = "gen_ai.request.n" + GEN_AI_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences" + GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue" + GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token" + GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e" + GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler" + # Time taken in the forward pass for this across all workers + GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = ( + "gen_ai.latency.time_in_model_forward") + # Time taken in the model execute function. This will include model + # forward, block/sync across workers, cpu-gpu sync time and sampling time. 
+ GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = ( + "gen_ai.latency.time_in_model_execute") + + +def contains_trace_headers(headers: Mapping[str, str]) -> bool: + return any(h in headers for h in TRACE_HEADERS) + + +@run_once +def log_tracing_disabled_warning() -> None: + logger.warning( + "Received a request with trace context but tracing is disabled") diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py new file mode 100644 index 0000000..01d5bb4 --- /dev/null +++ b/vllm/transformers_utils/__init__.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.envs import VLLM_USE_MODELSCOPE + +if VLLM_USE_MODELSCOPE: + # Patch here, before each import happens + import modelscope + from packaging import version + + # patch_hub begins from modelscope>=1.18.1 + if version.parse(modelscope.__version__) <= version.parse('1.18.0'): + raise ImportError( + 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' + 'install by `pip install modelscope -U`') + + from modelscope.utils.hf_util import patch_hub + + # Patch hub to download models from modelscope to speed up. + patch_hub() diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py new file mode 100644 index 0000000..d27a126 --- /dev/null +++ b/vllm/transformers_utils/config.py @@ -0,0 +1,773 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +import json +import os +import time +from functools import cache +from pathlib import Path +from typing import Any, Callable, Dict, Literal, Optional, Type, Union + +import huggingface_hub +from huggingface_hub import hf_hub_download +from huggingface_hub import list_repo_files as hf_list_repo_files +from huggingface_hub import try_to_load_from_cache +from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, + HFValidationError, LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError) +from torch import nn +from transformers import GenerationConfig, PretrainedConfig +from transformers.models.auto.image_processing_auto import ( + get_image_processor_config) +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME + +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, + DbrxConfig, DeepseekVLV2Config, + EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, + InternVLChatConfig, JAISConfig, + MedusaConfig, MllamaConfig, + MLPSpeculatorConfig, MPTConfig, + NemotronConfig, NVLM_D_Config, + Olmo2Config, RWConfig, + SkyworkR1VChatConfig, SolarConfig, + Telechat2Config, UltravoxConfig) +# yapf: enable +from vllm.transformers_utils.utils import check_gguf_file +from vllm.utils import resolve_obj_by_qualname + +if VLLM_USE_MODELSCOPE: + from modelscope import AutoConfig +else: + from transformers import AutoConfig + +MISTRAL_CONFIG_NAME = "params.json" +HF_TOKEN = os.getenv('HF_TOKEN', None) + +logger = init_logger(__name__) + +_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { + "mllama": MllamaConfig +} + +_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + "chatglm": ChatGLMConfig, + "cohere2": Cohere2Config, + "dbrx": DbrxConfig, + "deepseek_vl_v2": DeepseekVLV2Config, + "mpt": MPTConfig, + "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) + "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) + "jais": 
JAISConfig, + "mlp_speculator": MLPSpeculatorConfig, + "medusa": MedusaConfig, + "eagle": EAGLEConfig, + "exaone": ExaoneConfig, + "h2ovl_chat": H2OVLChatConfig, + "internvl_chat": InternVLChatConfig, + "nemotron": NemotronConfig, + "NVLM_D": NVLM_D_Config, + "olmo2": Olmo2Config, + "solar": SolarConfig, + "skywork_chat": SkyworkR1VChatConfig, + "telechat": Telechat2Config, + "ultravox": UltravoxConfig, + **_CONFIG_REGISTRY_OVERRIDE_HF +} + + +class ConfigFormat(str, enum.Enum): + AUTO = "auto" + HF = "hf" + MISTRAL = "mistral" + + +def with_retry(func: Callable[[], Any], + log_msg: str, + max_retries: int = 2, + retry_delay: int = 2): + for attempt in range(max_retries): + try: + return func() + except Exception as e: + if attempt == max_retries - 1: + logger.error("%s: %s", log_msg, e) + raise + logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1, + max_retries) + time.sleep(retry_delay) + retry_delay *= 2 + + +# @cache doesn't cache exceptions +@cache +def list_repo_files( + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, +) -> list[str]: + + def lookup_files() -> list[str]: + # directly list files if model is local + if (local_path := Path(repo_id)).exists(): + return [ + str(file.relative_to(local_path)) + for file in local_path.rglob('*') if file.is_file() + ] + # if model is remote, use hf_hub api to list files + try: + if VLLM_USE_MODELSCOPE: + from vllm.transformers_utils.utils import ( + modelscope_list_repo_files) + return modelscope_list_repo_files(repo_id, + revision=revision, + token=token) + return hf_list_repo_files(repo_id, + revision=revision, + repo_type=repo_type, + token=token) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Don't raise in offline mode, + # all we know is that we don't have this + # file cached. + return [] + + return with_retry(lookup_files, "Error retrieving file list") + + +def file_exists( + repo_id: str, + file_name: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[str, bool, None] = None, +) -> bool: + file_list = list_repo_files(repo_id, + repo_type=repo_type, + revision=revision, + token=token) + return file_name in file_list + + +# In offline mode the result can be a false negative +def file_or_path_exists(model: Union[str, Path], config_name: str, + revision: Optional[str]) -> bool: + if (local_path := Path(model)).exists(): + return (local_path / config_name).is_file() + + # Offline mode support: Check if config file is cached already + cached_filepath = try_to_load_from_cache(repo_id=model, + filename=config_name, + revision=revision) + if isinstance(cached_filepath, str): + # The config file exists in cache- we can continue trying to load + return True + + # NB: file_exists will only check for the existence of the config file on + # hf_hub. This will fail in offline mode. 
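+    # list_repo_files() swallows OfflineModeIsEnabled and returns an empty
+    # list, so in offline mode this check degrades to a false negative
+    # rather than raising.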
+ + # Call HF to check if the file exists + return file_exists(str(model), + config_name, + revision=revision, + token=HF_TOKEN) + + +def patch_rope_scaling(config: PretrainedConfig) -> None: + """Provide backwards compatibility for RoPE.""" + text_config = getattr(config, "text_config", None) + if text_config is not None: + patch_rope_scaling(text_config) + + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None: + patch_rope_scaling_dict(rope_scaling) + + +def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None: + if "rope_type" in rope_scaling and "type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + rope_type_legacy = rope_scaling["type"] + if rope_type != rope_type_legacy: + raise ValueError( + f"Found conflicts between 'rope_type={rope_type}' (modern " + f"field) and 'type={rope_type_legacy}' (legacy field). " + "You should only specify one of them.") + + if "rope_type" not in rope_scaling and "type" in rope_scaling: + rope_scaling["rope_type"] = rope_scaling["type"] + logger.info("Replacing legacy 'type' key with 'rope_type'") + + if "rope_type" not in rope_scaling: + raise ValueError("rope_scaling should have a 'rope_type' key") + + if rope_scaling["rope_type"] == "su": + rope_scaling["rope_type"] = "longrope" + logger.warning("Replacing legacy rope_type 'su' with 'longrope'") + elif rope_scaling["rope_type"] == "mrope": + assert "mrope_section" in rope_scaling + rope_scaling["rope_type"] = "default" + logger.warning("Replacing legacy rope_type 'mrope' with 'default'") + + +def uses_mrope(config: PretrainedConfig) -> bool: + """Detect if the model with this config uses M-ROPE.""" + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None: + return False + + return "mrope_section" in rope_scaling + + +def is_encoder_decoder(config: PretrainedConfig) -> bool: + """Detect if the model with this config is used as an encoder/decoder.""" + text_config = getattr(config, "text_config", None) + if text_config is not None: + return is_encoder_decoder(text_config) + + return getattr(config, "is_encoder_decoder", False) + + +def get_config( + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + config_format: ConfigFormat = ConfigFormat.AUTO, + **kwargs, +) -> PretrainedConfig: + # Separate model folder from file path for GGUF models + + is_gguf = check_gguf_file(model) + if is_gguf: + kwargs["gguf_file"] = Path(model).name + model = Path(model).parent + + if config_format == ConfigFormat.AUTO: + try: + if is_gguf or file_or_path_exists( + model, HF_CONFIG_NAME, revision=revision): + config_format = ConfigFormat.HF + elif file_or_path_exists(model, + MISTRAL_CONFIG_NAME, + revision=revision): + config_format = ConfigFormat.MISTRAL + else: + raise ValueError( + "Could not detect config format for no config file found. " + "Ensure your model has either config.json (HF format) " + "or params.json (Mistral format).") + + except Exception as e: + error_message = ( + "Invalid repository ID or local directory specified:" + " '{model}'.\nPlease verify the following requirements:\n" + "1. Provide a valid Hugging Face repository ID.\n" + "2. 
Specify a local directory that contains a recognized " + "configuration file.\n" + " - For Hugging Face models: ensure the presence of a " + "'config.json'.\n" + " - For Mistral models: ensure the presence of a " + "'params.json'.\n").format(model=model) + + raise ValueError(error_message) from e + + if config_format == ConfigFormat.HF: + config_dict, _ = PretrainedConfig.get_config_dict( + model, + revision=revision, + code_revision=code_revision, + token=HF_TOKEN, + **kwargs, + ) + + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained( + model, + revision=revision, + code_revision=code_revision, + token=HF_TOKEN, + **kwargs, + ) + else: + try: + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + token=HF_TOKEN, + **kwargs, + ) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + elif config_format == ConfigFormat.MISTRAL: + config = load_params_config(model, revision, token=HF_TOKEN, **kwargs) + else: + supported_formats = [ + fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO + ] + raise ValueError( + f"Unsupported config format: {config_format}. " + f"Supported formats are: {', '.join(supported_formats)}. " + f"Ensure your model uses one of these configuration formats " + f"or specify the correct format explicitly.") + + # Special architecture mapping check for GGUF models + if is_gguf: + if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + raise RuntimeError( + f"Can't get gguf config for {config.model_type}.") + model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] + config.update({"architectures": [model_type]}) + + patch_rope_scaling(config) + + if trust_remote_code: + maybe_register_config_serialize_by_value() + + return config + + +def try_get_local_file(model: Union[str, Path], + file_name: str, + revision: Optional[str] = 'main') -> Optional[Path]: + file_path = Path(model) / file_name + if file_path.is_file(): + return file_path + else: + try: + cached_filepath = try_to_load_from_cache(repo_id=model, + filename=file_name, + revision=revision) + if isinstance(cached_filepath, str): + return Path(cached_filepath) + except HFValidationError: + ... + return None + + +def get_hf_file_to_dict(file_name: str, + model: Union[str, Path], + revision: Optional[str] = 'main'): + """ + Downloads a file from the Hugging Face Hub and returns + its contents as a dictionary. + + Parameters: + - file_name (str): The name of the file to download. + - model (str): The name of the model on the Hugging Face Hub. + - revision (str): The specific version of the model. + + Returns: + - config_dict (dict): A dictionary containing + the contents of the downloaded file. 
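+    Returns None if the file is not available locally, cannot be found on
+    the Hub, or the Hub is unreachable (e.g. offline mode).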
+ """ + + file_path = try_get_local_file(model=model, + file_name=file_name, + revision=revision) + + if file_path is None: + try: + hf_hub_file = hf_hub_download(model, file_name, revision=revision) + except huggingface_hub.errors.OfflineModeIsEnabled: + return None + except (RepositoryNotFoundError, RevisionNotFoundError, + EntryNotFoundError, LocalEntryNotFoundError) as e: + logger.debug("File or repository not found in hf_hub_download", e) + return None + except HfHubHTTPError as e: + logger.warning( + "Cannot connect to Hugging Face Hub. Skipping file " + "download for '%s':", + file_name, + exc_info=e) + return None + file_path = Path(hf_hub_file) + + if file_path is not None and file_path.is_file(): + with open(file_path) as file: + return json.load(file) + + return None + + +@cache +def get_pooling_config(model: str, revision: Optional[str] = 'main'): + """ + This function gets the pooling and normalize + config from the model - only applies to + sentence-transformers models. + + Args: + model (str): The name of the Hugging Face model. + revision (str, optional): The specific version + of the model to use. Defaults to 'main'. + + Returns: + dict: A dictionary containing the pooling + type and whether normalization is used. + """ + + modules_file_name = "modules.json" + + modules_dict = None + if file_or_path_exists(model=model, + config_name=modules_file_name, + revision=revision): + modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) + + if modules_dict is None: + return None + + logger.info("Found sentence-transformers modules configuration.") + + pooling = next((item for item in modules_dict + if item["type"] == "sentence_transformers.models.Pooling"), + None) + normalize = bool( + next((item for item in modules_dict + if item["type"] == "sentence_transformers.models.Normalize"), + False)) + + if pooling: + + pooling_file_name = "{}/config.json".format(pooling["path"]) + pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) + pooling_type_name = next( + (item for item, val in pooling_dict.items() if val is True), None) + + if pooling_type_name is not None: + pooling_type_name = get_pooling_config_name(pooling_type_name) + + logger.info("Found pooling configuration.") + return {"pooling_type": pooling_type_name, "normalize": normalize} + + return None + + +def get_pooling_config_name(pooling_name: str) -> Union[str, None]: + if "pooling_mode_" in pooling_name: + pooling_name = pooling_name.replace("pooling_mode_", "") + + if "_" in pooling_name: + pooling_name = pooling_name.split("_")[0] + + if "lasttoken" in pooling_name: + pooling_name = "last" + + supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'] + pooling_type_name = pooling_name.upper() + + try: + if pooling_type_name in supported_pooling_types: + return pooling_type_name + except NotImplementedError as e: + logger.debug("Pooling type not supported", e) + return None + return None + + +@cache +def get_sentence_transformer_tokenizer_config(model: str, + revision: Optional[str] = 'main' + ): + """ + Returns the tokenization configuration dictionary for a + given Sentence Transformer BERT model. + + Parameters: + - model (str): The name of the Sentence Transformer + BERT model. + - revision (str, optional): The revision of the m + odel to use. Defaults to 'main'. + + Returns: + - dict: A dictionary containing the configuration parameters + for the Sentence Transformer BERT model. 
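+    Returns None if no sentence-transformers tokenizer configuration file is
+    found, or if the file does not define both "max_seq_length" and
+    "do_lower_case".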
+ """ + sentence_transformer_config_files = [ + "sentence_bert_config.json", + "sentence_roberta_config.json", + "sentence_distilbert_config.json", + "sentence_camembert_config.json", + "sentence_albert_config.json", + "sentence_xlm-roberta_config.json", + "sentence_xlnet_config.json", + ] + encoder_dict = None + + for config_file in sentence_transformer_config_files: + if try_get_local_file(model=model, + file_name=config_file, + revision=revision) is not None: + encoder_dict = get_hf_file_to_dict(config_file, model, revision) + if encoder_dict: + break + + if not encoder_dict and not model.startswith("/"): + try: + # If model is on HuggingfaceHub, get the repo files + repo_files = list_repo_files(model, + revision=revision, + token=HF_TOKEN) + except Exception: + repo_files = [] + + for config_name in sentence_transformer_config_files: + if config_name in repo_files: + encoder_dict = get_hf_file_to_dict(config_name, model, + revision) + if encoder_dict: + break + + if not encoder_dict: + return None + + logger.info("Found sentence-transformers tokenize configuration.") + + if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")): + return encoder_dict + return None + + +def maybe_register_config_serialize_by_value() -> None: + """Try to register HF model configuration class to serialize by value + + If trust_remote_code is set, and the model's config file specifies an + `AutoConfig` class, then the config class is typically an instance of + a custom class imported from the HF modules cache. + + Examples: + + >>> from transformers import AutoConfig + >>> klass = AutoConfig.from_pretrained('meta-llama/Meta-Llama-3-8B', trust_remote_code=True) + >>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig + >>> import transformers_modules # error, not initialized + >>> klass = AutoConfig.from_pretrained('deepseek-ai/DeepSeek-V2.5', trust_remote_code=True) + >>> import transformers_modules # success, initialized + >>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config + + In the DeepSeek example, the config class is an instance of a custom + class that is not serializable by default. This class will not be + importable in spawned workers, and won't exist at all on + other nodes, which breaks serialization of the config. + + In this function we tell the cloudpickle serialization library to pass + instances of these generated classes by value instead of by reference, + i.e. the class definition is serialized along with its data so that the + class module does not need to be importable on the receiving end. 
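+    If registration fails (for example, cloudpickle is unavailable), a
+    warning is logged and execution continues; serialization of the config
+    may then fail later on.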
+ + See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs + """ # noqa + try: + import transformers_modules + except ImportError: + # the config does not need trust_remote_code + return + + try: + import cloudpickle + cloudpickle.register_pickle_by_value(transformers_modules) + + # ray vendors its own version of cloudpickle + from vllm.executor.ray_utils import ray + if ray: + ray.cloudpickle.register_pickle_by_value(transformers_modules) + + # multiprocessing uses pickle to serialize arguments when using spawn + # Here we get pickle to use cloudpickle to serialize config objects + # that contain instances of the custom config class to avoid + # serialization problems if the generated module (and model) has a `.` + # in its name + import multiprocessing + import pickle + + from vllm.config import VllmConfig + + def _reduce_config(config: VllmConfig): + return (pickle.loads, (cloudpickle.dumps(config), )) + + multiprocessing.reducer.register(VllmConfig, _reduce_config) + + except Exception as e: + logger.warning( + "Unable to register remote classes used by" + " trust_remote_code with by-value serialization. This may" + " lead to a later error. If remote code is not needed" + " remove `--trust-remote-code`", + exc_info=e) + + +def load_params_config(model: Union[str, Path], revision: Optional[str], + **kwargs) -> PretrainedConfig: + # This function loads a params.json config which + # should be used when loading models in mistral format + + config_file_name = "params.json" + + config_dict = get_hf_file_to_dict(config_file_name, model, revision) + assert isinstance(config_dict, dict) + + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": "intermediate_size", + } + + def recurse_elems(elem: Any): + if isinstance(elem, dict): + config_dict = {} + for key, value in elem.items(): + key = config_mapping.get(key, key) + config_dict[key] = recurse_elems(value) + + return config_dict + else: + return elem + + config_dict["model_type"] = config_dict.get("model_type", "transformer") + config_dict["hidden_act"] = config_dict.get("activation", "silu") + config_dict["tie_word_embeddings"] = config_dict.get( + "tie_embeddings", False) + config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000) + config_dict["max_position_embeddings"] = config_dict.get( + "max_position_embeddings", 128_000) + + if config_dict.get("quantization") is not None: + quantization = config_dict.get("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + # This maps to the FP8 static per-tensor quantization scheme + quantization_config = { + "quant_method": "fp8", + "activation_scheme": "static" + } + else: + raise ValueError( + f"Found unknown quantization='{quantization}' in config") + + config_dict["quantization_config"] = quantization_config + + config_type: Literal["text", + "multimodal"] = "multimodal" if config_dict.get( + "vision_encoder") is not None else "text" + + if config_dict.get("moe") is not None: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + if config_type == "multimodal": + multimodal_config = config_dict.pop("vision_encoder") + + config_dict = { + "text_config": config_dict, + "vision_config": multimodal_config + } + config_dict["architectures"] = 
["PixtralForConditionalGeneration"] + config_dict["model_type"] = "pixtral" + + config_dict.update(kwargs) + + config_dict = recurse_elems(config_dict) + + # transform to HF config format + if config_type == "multimodal": + config_dict["text_config"] = PretrainedConfig( + **config_dict["text_config"]) + config_dict["vision_config"] = PretrainedConfig( + **config_dict["vision_config"]) + + return PretrainedConfig(**config_dict) + + +def get_hf_image_processor_config( + model: Union[str, Path], + revision: Optional[str] = None, + **kwargs, +) -> Dict[str, Any]: + # ModelScope does not provide an interface for image_processor + if VLLM_USE_MODELSCOPE: + return dict() + # Separate model folder from file path for GGUF models + if check_gguf_file(model): + model = Path(model).parent + return get_image_processor_config(model, revision=revision, **kwargs) + + +def get_hf_text_config(config: PretrainedConfig): + """Get the "sub" config relevant to llm for multi modal models. + No op for pure text models. + """ + if hasattr(config, "text_config"): + # The code operates under the assumption that text_config should have + # `num_attention_heads` (among others). Assert here to fail early + # if transformers config doesn't align with this assumption. + assert hasattr(config.text_config, "num_attention_heads") + return config.text_config + else: + return config + + +def try_get_generation_config( + model: str, + trust_remote_code: bool, + revision: Optional[str] = None, +) -> Optional[GenerationConfig]: + try: + return GenerationConfig.from_pretrained( + model, + revision=revision, + ) + except OSError: # Not found + try: + config = get_config( + model, + trust_remote_code=trust_remote_code, + revision=revision, + ) + return GenerationConfig.from_model_config(config) + except OSError: # Not found + return None + + +def get_cross_encoder_activation_function(config: PretrainedConfig): + if (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + + function_name = config.sbert_ce_default_activation_function + assert function_name.startswith("torch.nn.modules."), \ + "Loading of activation functions is restricted to " \ + "torch.nn.modules for security reasons" + return resolve_obj_by_qualname(function_name)() + else: + return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py new file mode 100644 index 0000000..5369934 --- /dev/null +++ b/vllm/transformers_utils/configs/__init__.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.transformers_utils.configs.chatglm import ChatGLMConfig +from vllm.transformers_utils.configs.cohere2 import Cohere2Config +from vllm.transformers_utils.configs.dbrx import DbrxConfig +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config +from vllm.transformers_utils.configs.eagle import EAGLEConfig +from vllm.transformers_utils.configs.exaone import ExaoneConfig +# RWConfig is for the original tiiuae/falcon-40b(-instruct) and +# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the +# `FalconConfig` class from the official HuggingFace transformers library. 
+from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig +from vllm.transformers_utils.configs.internvl import InternVLChatConfig +from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.mllama import MllamaConfig +from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig +from vllm.transformers_utils.configs.mpt import MPTConfig +from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.olmo2 import Olmo2Config +from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig +from vllm.transformers_utils.configs.solar import SolarConfig +from vllm.transformers_utils.configs.telechat2 import Telechat2Config +from vllm.transformers_utils.configs.ultravox import UltravoxConfig + +__all__ = [ + "ChatGLMConfig", + "Cohere2Config", + "DbrxConfig", + "DeepseekVLV2Config", + "MPTConfig", + "RWConfig", + "H2OVLChatConfig", + "InternVLChatConfig", + "JAISConfig", + "MedusaConfig", + "EAGLEConfig", + "ExaoneConfig", + "MllamaConfig", + "MLPSpeculatorConfig", + "NemotronConfig", + "NVLM_D_Config", + "Olmo2Config", + "SkyworkR1VChatConfig", + "SolarConfig", + "Telechat2Config", + "UltravoxConfig", +] diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py new file mode 100644 index 0000000..5ab70c0 --- /dev/null +++ b/vllm/transformers_utils/configs/arctic.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 + +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py +""" Arctic model configuration""" + +from dataclasses import asdict, dataclass +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", +} + + +@dataclass +class ArcticLoRAConfig: + lora_r: int = 64 + lora_alpha: float = 16 + shard_base_weights: bool = False + + +@dataclass +class ArcticQuantizationConfig: + q_bits: int = 8 + rounding: str = "nearest" + mantissa_bits: int = 3 + group_size: int = 128 + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. 
+ num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import ArcticModel, ArcticConfig + + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. 
+ >>> configuration = ArcticConfig() + + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + quantization=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + if isinstance(quantization, dict): + self.quantization = ArcticQuantizationConfig(**quantization) + else: + self.quantization = quantization + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": + result = super().from_dict(config_dict, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if isinstance(config.quantization, dict): + config.quantization = ArcticQuantizationConfig(**config.quantization) + return result + + def to_dict(self) -> Dict[str, Any]: + ret = super().to_dict() + if isinstance(ret["quantization"], ArcticQuantizationConfig): + ret["quantization"] = asdict(ret["quantization"]) + return ret diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py new file mode 100644 index 0000000..43e9503 --- /dev/null +++ b/vllm/transformers_utils/configs/chatglm.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + 
model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + # It is to be compatible with long lora. + self.max_position_embeddings = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + self.interleaved_qkv = interleaved_qkv + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/cohere2.py b/vllm/transformers_utils/configs/cohere2.py new file mode 100644 index 0000000..e30409b --- /dev/null +++ b/vllm/transformers_utils/configs/cohere2.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa + +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere2/configuration_cohere2.py +from transformers import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class Cohere2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CohereModel`]. It is used to instantiate an Cohere + model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a configuration + with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model. + + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Cohere model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`CohereModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22528): + Dimension of the MLP representations. + logit_scale (`float`, *optional*, defaults to 0.0625): + The scaling factor for the output logits. + num_hidden_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 5): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 255001): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. 
The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + sliding_window (`int`, *optional*, defaults to 4096): + Size of the sliding window attention context. + sliding_window_pattern (`int`, *optional*, defaults to 4): + Pattern for the sliding window attention. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. 
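The `rope_scaling` contract documented above is easiest to see with a concrete instantiation. The snippet below is a minimal, illustrative sketch rather than part of the diff: it assumes the patch is applied so the module is importable as `vllm.transformers_utils.configs.cohere2`, that the installed `transformers` provides `rope_config_validation`, and the small layer/head counts are made up for illustration.

```python
from vllm.transformers_utils.configs.cohere2 import Cohere2Config

# Downsized, hypothetical configuration; only the rope / sliding-window fields
# matter for this illustration.
config = Cohere2Config(
    hidden_size=1024,
    num_hidden_layers=2,
    num_attention_heads=8,
    sliding_window=4096,          # width of the sliding-window attention context
    sliding_window_pattern=4,     # pattern controlling which layers use the window
    rope_theta=10000.0,
    # A 'linear' rope_scaling dict carries a rope_type and a factor >= 1;
    # rope_config_validation() is called from __init__ and checks this dict.
    rope_scaling={"rope_type": "linear", "factor": 2.0},
)

print(config.head_dim)       # hidden_size // num_attention_heads == 128
print(config.rope_scaling)   # {'rope_type': 'linear', 'factor': 2.0}
```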
+ + ```python + >>> from transformers import Cohere2Model, Cohere2Config + + >>> # Initializing a Cohere Nextmodel configuration + >>> configuration = Cohere2Config() + + >>> # Initializing a model from the Cohere2 configuration + >>> model = Cohere2Model(configuration) # doctest: +SKIP + + >>> # Accessing the model configuration + >>> configuration = model.config # doctest: +SKIP + ``` + """ + + model_type = "cohere2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=8192, + intermediate_size=22528, + logit_scale=0.0625, + num_hidden_layers=40, + num_attention_heads=64, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=8192, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=5, + eos_token_id=255001, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + sliding_window=4096, + sliding_window_pattern=4, + cache_implementation="hybrid", + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.logit_scale = logit_scale + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.sliding_window = sliding_window + self.sliding_window_pattern = sliding_window_pattern + # Need to specify head_dim in the config so it can be used in the attention forward functions + self.head_dim = hidden_size // num_attention_heads + self.cache_implementation = cache_implementation + + # Validate the correctness of rotary position embeddings parameters + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Cohere2Config"] diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py new file mode 100644 index 0000000..8f40b2b --- /dev/null +++ b/vllm/transformers_utils/configs/dbrx.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 + +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py +"""Dbrx configuration.""" + +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore + + +class DbrxAttentionConfig(PretrainedConfig): + """Configuration class for Dbrx Attention. + + [`DbrxAttention`] class. It is used to instantiate attention layers + according to the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + attn_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention layers. + clip_qkv (`float`, *optional*, defaults to None): + If not `None`, clip the queries, keys, and values in the attention layer to this value. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. + rope_theta (float): The base frequency for rope. + """ + + def __init__( + self, + attn_pdrop: float = 0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["attn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all configurations of " + "models and can yield errors.", + config_dict["model_type"], cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxFFNConfig(PretrainedConfig): + """Configuration class for Dbrx FFN. + + [`DbrxFFN`] class. It is used to instantiate feedforward layers according to + the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + ffn_act_fn (dict, optional): A dict specifying activation function for the FFN. + The dict should have a key 'name' with the value being the name of + the activation function along with any additional keyword arguments. + ffn_hidden_size (int, optional): The hidden size of the feedforward network. + moe_num_experts (int, optional): The number of experts in the mixture of experts layer. + moe_top_k (int, optional): The number of experts to use in the mixture of experts layer. + moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer. + moe_loss_weight (float, optional): The loss weight for the mixture of experts layer. + moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights. + uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment. + This should only be used for benchmarking purposes. 
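As a concrete illustration of the nested-config pattern used in this file, the sketch below (illustrative only, not part of the diff) builds the attention and FFN sub-configs directly and also lets `DbrxConfig`, defined further down in this file, construct them from plain dicts as they would appear in a checkpoint's `config.json`. The import path assumes the patch is applied inside the vLLM tree; the numeric values are made up.

```python
from vllm.transformers_utils.configs.dbrx import (DbrxAttentionConfig,
                                                  DbrxConfig, DbrxFFNConfig)

# Sub-configs built directly as objects...
attn = DbrxAttentionConfig(clip_qkv=8.0, kv_n_heads=2, rope_theta=500000.0)
ffn = DbrxFFNConfig(ffn_hidden_size=3584, moe_num_experts=4, moe_top_k=1)
config_a = DbrxConfig(d_model=2048, n_heads=16, attn_config=attn, ffn_config=ffn)

# ...or passed as plain dicts, the form found in a serialized config.json.
config_b = DbrxConfig(
    d_model=2048,
    n_heads=16,
    attn_config={"clip_qkv": 8.0, "kv_n_heads": 2, "rope_theta": 500000.0},
    ffn_config={"ffn_hidden_size": 3584, "moe_num_experts": 4, "moe_top_k": 1},
)

print(config_a.attn_config.kv_n_heads)  # 2
print(config_b.ffn_config.moe_top_k)    # 1
```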
+ """ + + def __init__( + self, + ffn_act_fn: Optional[dict] = None, + ffn_hidden_size: int = 3584, + moe_num_experts: int = 4, + moe_top_k: int = 1, + moe_jitter_eps: Optional[float] = None, + moe_loss_weight: float = 0.01, + moe_normalize_expert_weights: Optional[float] = 1, + uniform_expert_assignment: bool = False, + **kwargs: Any, + ): + super().__init__() + if ffn_act_fn is None: + ffn_act_fn = {"name": "silu"} + self.ffn_act_fn = ffn_act_fn + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_loss_weight = moe_loss_weight + self.moe_normalize_expert_weights = moe_normalize_expert_weights + self.uniform_expert_assignment = uniform_expert_assignment + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["ffn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all " + "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxConfig(PretrainedConfig): + """Configuration class for Dbrx. + + [`DbrxModel`]. It is used to instantiate a Dbrx model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + d_model (`int`, *optional*, defaults to 6144): + Dimensionality of the embeddings and hidden states. + n_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the Transformer encoder. + n_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + max_seq_len (`int`, *optional*, defaults to 32768): + The maximum sequence length of the model. + vocab_size (`int`, *optional*, defaults to 100352): + Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DbrxModel`]. + resid_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the attention output before combining with residual. + emb_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the embedding layer. + attn_config (`dict`, *optional*): + A dictionary used to configure the model's attention module. + ffn_config (`dict`, *optional*): + A dictionary used to configure the model's FFN module. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + + Example: + ```python + >>> from transformers import DbrxConfig, DbrxModel + + >>> # Initializing a Dbrx configuration + >>> configuration = DbrxConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DbrxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dbrx" + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + "max_position_embeddings": "max_seq_len", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + max_seq_len: int = 2048, + vocab_size: int = 32000, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + attn_config: Optional[DbrxAttentionConfig] = None, + ffn_config: Optional[DbrxFFNConfig] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + output_router_logits: bool = False, + router_aux_loss_coef: float = 0.05, + **kwargs: Any, + ): + if attn_config is None: + self.attn_config = DbrxAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = DbrxAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + + if ffn_config is None: + self.ffn_config = DbrxFFNConfig() + elif isinstance(ffn_config, dict): + self.ffn_config = DbrxFFNConfig(**ffn_config) + else: + self.ffn_config = ffn_config + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.use_cache = use_cache + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models." 
+ ) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py new file mode 100644 index 0000000..24d4052 --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_vl2.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268 +from typing import Tuple + +from transformers.configuration_utils import PretrainedConfig + + +class VisionEncoderConfig(PretrainedConfig): + model_type: str = "vision" + + model_name: str = "vit_so400m_patch14_siglip_384.webli" + image_size: int = 384 + patch_size: int = 16 + width: int = 1024 + layers: int = 24 + heads: int = 16 + mlp_ratio: int = 4 + global_pool: str = "map" + ignore_head: bool = True + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + weight_init: str = "skip" + deterministic: bool = False + num_recomputing_layers: int = 0 + + def __init__(self, + model_name: str = "vit_so400m_patch14_siglip_384.webli", + image_size: int = 384, + patch_size: int = 16, + width: int = 1024, + layers: int = 24, + heads: int = 16, + mlp_ratio: int = 4, + global_pool: str = "map", + ignore_head: bool = True, + class_token: bool = False, + num_classes: int = 0, + use_checkpoint: bool = False, + **kwargs): + self.model_name = model_name + self.image_size = image_size + self.patch_size = patch_size + self.width = width + self.layers = layers + self.heads = heads + self.mlp_ratio = mlp_ratio + self.global_pool = global_pool + self.ignore_head = ignore_head + self.class_token = class_token + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + + super().__init__(**kwargs) + + +class MlpProjectorConfig(PretrainedConfig): + model_type = "mlp_projector" + projector_type: str = "downsample_mlp_gelu" + input_dim: int = 1152 + n_embed: int = 2048 + depth: int = 2 + mlp_ratio: int = 1 + downsample_ratio: int = 2 + token_pooling: bool = False + + def __init__(self, + projector_type: str = "downsample_mlp_gelu", + input_dim: int = 1152, + n_embed: int = 2048, + depth: int = 2, + mlp_ratio: int = 1, + downsample_ratio: int = 2, + **kwargs): + self.projector_type = projector_type + self.input_dim = input_dim + self.n_embed = n_embed + self.depth = depth + self.mlp_ratio = mlp_ratio + self.downsample_ratio = downsample_ratio + + super().__init__(**kwargs) + + +class DeepseekV2Config(PretrainedConfig): + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=None, + n_routed_experts=None, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='gready', + n_group=None, + topk_group=None, + num_experts_per_tok=None, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func='softmax', + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + 
rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + use_mla=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = float(rms_norm_eps) + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.use_mla = use_mla + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekVLV2Config(PretrainedConfig): + model_type = "deepseek_vl_v2" + vision_config: VisionEncoderConfig + projector_config: MlpProjectorConfig + + tile_tag: str = "2D" + global_view_pos: str = "head" + candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384), ) + + def __init__(self, + tile_tag: str = "tile_tag", + global_view_pos: str = "head", + candidate_resolutions: Tuple[Tuple[int, + int]] = ((384, 384), ), + **kwargs): + super().__init__(**kwargs) + + vision_config = kwargs.get("vision_config", {}) + self.vision_config = VisionEncoderConfig(**vision_config) + + projector_config = kwargs.get("projector_config", {}) + self.projector_config = MlpProjectorConfig(**projector_config) + + language_config = kwargs.get("language_config", {}) + self.text_config = DeepseekV2Config(**language_config) + + self.tile_tag = tile_tag + self.global_view_pos = global_view_pos + self.candidate_resolutions = candidate_resolutions + self.vocab_size = self.text_config.vocab_size diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py new file mode 100644 index 0000000..dd80606 --- /dev/null +++ b/vllm/transformers_utils/configs/eagle.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Optional, Union + +from transformers import AutoConfig, PretrainedConfig + +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config + + +class EAGLEConfig(PretrainedConfig): + model_type = "eagle" + + def __init__(self, + model: Union[PretrainedConfig, dict, None] = None, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + model_config: Union[PretrainedConfig, DeepseekV2Config, None] + if 
isinstance(model, dict): + archs = model.get("architectures", []) + target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"] + if any(target_arch in archs for target_arch in target_archs): + # AutoConfig does not support DeepSeek MoE models yet + model_config = DeepseekV2Config(**model) + else: + model_config = AutoConfig.for_model(**model) + else: + model_config = model + + for k, v in kwargs.items(): + if k != "architectures" and k != "model_type" and hasattr( + model_config, k): + setattr(model_config, k, v) + + self.model = model_config + + if self.model is None: + self.truncated_vocab_size = None + else: + self.truncated_vocab_size = self.model.vocab_size if \ + truncated_vocab_size is None else truncated_vocab_size + + if "architectures" not in kwargs: + kwargs["architectures"] = ["EAGLEModel"] + + super().__init__(**kwargs) + + if self.model is not None: + for k, v in self.model.to_dict().items(): + if not hasattr(self, k): + setattr(self, k, v) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "EAGLEConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/exaone.py b/vllm/transformers_utils/configs/exaone.py new file mode 100644 index 0000000..3936436 --- /dev/null +++ b/vllm/transformers_utils/configs/exaone.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copied from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/configuration_exaone.py +# Copyright 2021 The LG AI Research EXAONE Lab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Exaone model configuration""" + +from typing import Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {} + + +class ExaoneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class: + `~transformers.ExaoneModel`. It is used to instantiate a GPT Lingvo model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Exaone + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` + and can be used to control the model outputs. Read the documentation from : + class:`~transformers.PretrainedConfig` for more information. + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50257): + Vocabulary size of the GPT Lingvo model. Defines the number of + different tokens that can be represented by the :obj:`inputs_ids` + passed when calling :class:`~transformers.ExaoneModel`. Vocabulary + size of the model. 
+ Defines the different tokens that can be represented by the + `inputs_ids` passed to the forward method of :class: + `~transformers.EXAONEModel`. + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1 the model will use + Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed by meanpooling + all the original heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + rotary_pct (`float`, *optional*, defaults to 0.25): + percentage of hidden dimensions to allocate to rotary embeddings + intermediate_size (:obj:`int`, `optional`, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in + the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, + defaults to :obj:`"gelu_new"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, :obj:`"gelu"`, :obj:`"relu"`, + :obj:`"selu"` and :obj:`"gelu_new"` are supported. + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the + embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling + :class:`~transformers.EXAONEModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + Only relevant if ``config.is_decoder=True``. + gradient_checkpointing (:obj:`bool`, `optional`, + defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense + of slower backward pass. 
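For readers unfamiliar with the grouped-query-attention convention described above, the following minimal sketch (illustrative only; the import path assumes the patch is applied and the values are made up) shows how `num_key_value_heads` selects between MHA, GQA and MQA, and how the class's `attribute_map` exposes `num_layers` under the standard `num_hidden_layers` name.

```python
from vllm.transformers_utils.configs.exaone import ExaoneConfig

config = ExaoneConfig(
    hidden_size=4096,
    num_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,     # 32 -> MHA, 1 -> MQA, anything in between -> GQA
    intermediate_size=14336,   # if omitted, falls back to 4 * hidden_size
)

print(config.num_hidden_layers)    # 32, aliased to num_layers
print(config.num_key_value_heads)  # 8
```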
+ Example:: + + >>> from transformers import ExoneModel, ExaoneConfig + + >>> # Initializing a EXAONE configuration + >>> configuration = ExaoneConfig() + + >>> # Initializing a model from configuration + >>> model = ExoneModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + + model_type = "exaone" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=102400, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + intermediate_size=None, + activation_function="silu", + rotary_pct=0.25, + resid_dropout=0.0, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, + initializer_range=0.02, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_layers + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.activation_function = activation_function + self.resid_dropout = resid_dropout + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rotary_pct = rotary_pct + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.use_logit_cap = kwargs.pop("use_logit_cap", False) + self.ln_no_scale = kwargs.pop("ln_no_scale", False) + self.use_gated = kwargs.pop("use_gated", False) + self.use_emb_norm = kwargs.pop("use_emb_norm", False) + self.use_rotary_pos = kwargs.pop("use_rotary_pos", False) + self.rotary_type = kwargs.pop("rotary_type", None) + self.scaling_factor = kwargs.pop("scaling_factor", 1) + self.use_absolute_pos = kwargs.pop("use_absolute_pos", True) + self.use_extra_logit = kwargs.pop("use_extra_logit", True) + self.rotary_expand_length = kwargs.pop("rotary_expand_length", None) + self.rotary_base = kwargs.pop("rotary_base", 10000.0) + self.use_qkv_fuse = kwargs.pop("use_qkv_fuse", False) + self.rescale_before_lm_head = kwargs.pop("rescale_before_lm_head", + (rotary_pct == 0.25)) + if self.use_rotary_pos: + self.use_absolute_pos = False diff --git a/vllm/transformers_utils/configs/falcon.py b/vllm/transformers_utils/configs/falcon.py new file mode 100644 index 0000000..f161a06 --- /dev/null +++ b/vllm/transformers_utils/configs/falcon.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py +# Copyright 2023 The vLLM team. +# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon configuration""" +from transformers.configuration_utils import PretrainedConfig + + +class RWConfig(PretrainedConfig): + model_type = "falcon" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "num_kv_heads": "n_head_kv", + } + + def __init__( + self, + vocab_size=250880, + hidden_size=64, + n_layer=2, + n_head=8, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + multi_query=True, + n_head_kv=None, + alibi=False, + bias=False, + parallel_attn=False, + new_decoder_architecture=False, + **kwargs, + ) -> None: + self.vocab_size = vocab_size + # Backward compatibility with n_embed kwarg + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.multi_query = multi_query + self.n_head_kv = 1 if n_head_kv is None else n_head_kv + self.alibi = alibi + self.bias = bias + self.parallel_attn = parallel_attn + self.new_decoder_architecture = new_decoder_architecture + + if self.hidden_size == 8192: + # Hack for falcon-40b + self.new_decoder_architecture = True + + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) + + @property + def head_dim(self): + return self.hidden_size // self.n_head + + @property + def rotary(self): + return not self.alibi diff --git a/vllm/transformers_utils/configs/h2ovl.py b/vllm/transformers_utils/configs/h2ovl.py new file mode 100644 index 0000000..48b5d79 --- /dev/null +++ b/vllm/transformers_utils/configs/h2ovl.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- + +from .internvl import InternVLChatConfig + + +class H2OVLChatConfig(InternVLChatConfig): + model_type = "h2ovl_chat" diff --git a/vllm/transformers_utils/configs/internvl.py b/vllm/transformers_utils/configs/internvl.py new file mode 100644 index 0000000..8ea6254 --- /dev/null +++ b/vllm/transformers_utils/configs/internvl.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from transformers.configuration_utils 
import PretrainedConfig + + +class InternVLChatConfig(PretrainedConfig): + model_type = 'internvl_chat' + is_composition = True + + def __init__(self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + select_layer=-1, + force_image_size=None, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + ps_version='v1', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + + if llm_config is None: + llm_config = {} + + self.vision_config = PretrainedConfig(**vision_config) + self.text_config = PretrainedConfig(**llm_config) + + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.select_layer = select_layer + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py new file mode 100644 index 0000000..be0f3b7 --- /dev/null +++ b/vllm/transformers_utils/configs/jais.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAIS configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class JAISConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a + [`JAISModel`]. It is used to instantiate a JAIS model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the JAIS model. Defines the number of different + tokens that can be represented by the + `inputs_ids` passed when calling [`JAISModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used + with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the + Transformer encoder. 
+ n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set + it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list + `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in + the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, + defaults to `False`): + Whether to additionally scale attention weights by + `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention + (dot-product) + and upcast attention dot-product/softmax to float() when training + with mixed precision. + position_embedding_type (`str`, *optional*, defaults to `"learned"`): + Positional embedding can be either `"alibi"` or `"learned"`. + mup_width_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale learning rate and initializers. Calculated + as (`d_model,0 / d_model`), where + `d_model` is the model's width and `d_model,0` is the proxy + model's width. + mup_embeddings_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale token and position embeddings. + mup_output_alpha (`float`, *optional*, defaults to 1.0): + muP parameter to scale output logits + (`output_logits_scale = mup_output_alpha * mup_width_scale`). + mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`): + Scale attention weights by dividing by hidden_size instead of + sqrt(hidden_size). Need to set scale_attn_weights to `True` as + well. + alibi_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for ALiBi + embeddings. Currently only supports linear + scaling strategy. Can specify either the scaling `factor` (must be + a float greater than 1) for fixed scaling + or `train_seq_len` for dynamic scaling on input samples with + sequence length > `train_seq_len`. The expected + formats are `{"type": strategy name, "factor": scaling factor}` or + `{"type": strategy name, + "train_seq_len": training sequence length}`. + architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']): + architecture names for Jais. 
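The two accepted `alibi_scaling` layouts described above are checked by `_alibi_scaling_validation()` at construction time. A minimal sketch follows (illustrative only; the import path assumes the patch is applied inside the vLLM tree):

```python
from vllm.transformers_utils.configs.jais import JAISConfig

# Fixed scaling: 'factor' must be a float strictly greater than 1.0.
cfg_fixed = JAISConfig(
    position_embedding_type="alibi",
    alibi_scaling={"type": "linear", "factor": 2.0},
)

# Dynamic scaling: applied to samples whose length exceeds 'train_seq_len'.
cfg_dynamic = JAISConfig(
    position_embedding_type="alibi",
    alibi_scaling={"type": "linear", "train_seq_len": 2048},
)

print(cfg_fixed.alibi_scaling)    # {'type': 'linear', 'factor': 2.0}
print(cfg_dynamic.alibi_scaling)  # {'type': 'linear', 'train_seq_len': 2048}
```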
+ + Example: + + ```python + >>> from transformers import JAISConfig, JAISModel + + >>> # Initializing a JAIS configuration + >>> configuration = JAISConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = JAISModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "jais" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + position_embedding_type="learned", + mup_width_scale=1.0, + mup_embeddings_scale=1.0, + mup_output_alpha=1.0, + mup_scale_qk_dot_by_d=False, + alibi_scaling=None, + architectures=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.position_embedding_type = position_embedding_type + self.mup_width_scale = mup_width_scale + self.mup_embeddings_scale = mup_embeddings_scale + self.mup_output_alpha = mup_output_alpha + self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d + + self.alibi_scaling = alibi_scaling + self._alibi_scaling_validation() + if architectures is None: + architectures = ["JAISLMHeadModel"] + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + architectures=architectures, + **kwargs, + ) + + def _alibi_scaling_validation(self): + """ + Validate the `alibi_scaling` configuration. 
+ """ + if self.alibi_scaling is None: + return + + if (not isinstance(self.alibi_scaling, dict) + or len(self.alibi_scaling) != 2): + raise ValueError( + "`alibi_scaling` must be a dictionary with two fields, " + "`type` and `factor` or `type` and `train_seq_len`, " + f"got {self.alibi_scaling}") + alibi_scaling_type = self.alibi_scaling.get("type", None) + alibi_scaling_factor = self.alibi_scaling.get("factor", None) + alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) + if alibi_scaling_type is None or alibi_scaling_type != "linear": + raise ValueError(f"`alibi_scaling`'s type field must be 'linear', " + f"got {alibi_scaling_type}") + if (alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or (alibi_scaling_factor is not None + and alibi_scaling_factor <= 1.0)): + raise ValueError( + f"`alibi_scaling`'s factor field must be a float > 1.0, " + f"got {alibi_scaling_factor}") + if (alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or (alibi_dynamic_scaling is not None + and alibi_dynamic_scaling <= 1)): + raise ValueError( + f"`alibi_scaling`'s `train_seq_len` field must be an " + f"integer > 1, got {alibi_dynamic_scaling}") diff --git a/vllm/transformers_utils/configs/medusa.py b/vllm/transformers_utils/configs/medusa.py new file mode 100644 index 0000000..885713c --- /dev/null +++ b/vllm/transformers_utils/configs/medusa.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class MedusaConfig(PretrainedConfig): + model_type = "medusa" + + def __init__(self, + hidden_size: int = 4096, + vocab_size: int = 32001, + num_heads: int = 5, + num_hidden_layers: int = 1, + max_paths: int = 64, + topk: int = 10, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.num_heads = num_heads + self.num_hidden_layers = num_hidden_layers + self.max_paths = max_paths + self.topk = topk + self.max_seq_len = int(2**20) + self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\ + else truncated_vocab_size + if "architectures" not in kwargs: + kwargs["architectures"] = ["MedusaModel"] + + super().__init__(**kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "MedusaConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + for k in list(config_dict.keys()): + if 'num' in k: + if 'heads' in k: + config_dict["num_heads"] = config_dict.pop(k) + elif 'layers' in k: + config_dict["num_hidden_layers"] = config_dict.pop(k) + return cls.from_dict(config_dict, **kwargs) + + @property + def num_attention_heads(self): + return 0 + + @property + def num_lookahead_tokens(self): + return self.num_heads + + @num_lookahead_tokens.setter + def num_lookahead_tokens(self, num_lookahead_tokens: int): + self.num_heads = num_lookahead_tokens diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py new file mode 100644 index 0000000..eb77e09 --- /dev/null +++ b/vllm/transformers_utils/configs/mllama.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 + +from transformers.models.mllama import configuration_mllama as mllama_hf_config + + +class MllamaTextConfig(mllama_hf_config.MllamaTextConfig): + ''' + Use this class to override is_encoder_decoder: + - transformers regards mllama as 
is_encoder_decoder=False + - vllm needs is_encoder_decoder=True to enable cross-attention + ''' + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + self.is_encoder_decoder = True + + +class MllamaConfig(mllama_hf_config.MllamaConfig): + + def __init__( + self, + text_config=None, + **kwargs, + ): + if isinstance(text_config, dict): + text_config = MllamaTextConfig(**text_config) + super().__init__(text_config=text_config, **kwargs) diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py new file mode 100644 index 0000000..c761f65 --- /dev/null +++ b/vllm/transformers_utils/configs/mlp_speculator.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from transformers import PretrainedConfig + + +class MLPSpeculatorConfig(PretrainedConfig): + model_type = "mlp_speculator" + + attribute_map = { + "hidden_size": "emb_dim", + } + + def __init__(self, + vocab_size: int = 32000, + emb_dim: int = 4096, + inner_dim: int = 0, + n_predict: int = 3, + top_k_tokens_per_head: Optional[List[int]] = None, + n_candidates: int = 5, + tie_weights: bool = False, + scale_input: bool = False, + **kwargs): + """ + Initialize an MLPSpeculatorConfig + + Args: + vocab_size: int + the model vocab size + emb_dim: int + the model embedding dimension + inner_dim: int + the inner dimension of the model. If 0, will be the emb_dim. + n_predict: int + the number of lookaheads for the speculator + top_k_tokens_per_head: List[int] + Number of tokens to consider from each head when forming the + candidate tree. + For each candidate branch in the tree, head n produces topk[n] + additional sub-branches. + NOTE: This parameter is currently unused. + n_candidates: int + number of child candidates to create per sequence + tie_weights: bool + If true, use a single set of weights for every model + head/stage after the first. The initial projection + from the base model may have a different size, so that + stays separate. + scale_input: bool + if True, will scale the initial hidden states from + the base model. 
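A minimal sketch of how these fields fit together (illustrative only; the import path assumes the patch is applied and the values are made up): `top_k_tokens_per_head` must have exactly `n_predict` entries, and `n_predict` is also surfaced as `num_lookahead_tokens` on the resulting config.

```python
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig

config = MLPSpeculatorConfig(
    vocab_size=32000,
    emb_dim=4096,
    inner_dim=0,                      # 0 means the inner dim falls back to emb_dim
    n_predict=3,
    top_k_tokens_per_head=[5, 4, 3],  # one entry per speculator head
    tie_weights=True,
)

print(config.num_lookahead_tokens)    # 3, mirrors n_predict
```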
+ """ + if top_k_tokens_per_head is None: + top_k_tokens_per_head = [5, 4, 3] + assert len(top_k_tokens_per_head) == n_predict + self.vocab_size = vocab_size + self.emb_dim = emb_dim + self.inner_dim = inner_dim + self.n_predict = n_predict + self.top_k_tokens_per_head = top_k_tokens_per_head + self.n_candidates = n_candidates + self.num_lookahead_tokens = n_predict + self.tie_weights = tie_weights + self.scale_input = scale_input + + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/mpt.py b/vllm/transformers_utils/configs/mpt.py new file mode 100644 index 0000000..9635613 --- /dev/null +++ b/vllm/transformers_utils/configs/mpt.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copied from +# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py +"""A HuggingFace-style model configuration.""" +import warnings +from typing import Any, Dict, Optional, Union + +from transformers import PretrainedConfig + +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8 +} +ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'} +init_config_defaults: Dict = { + 'name': 'kaiming_normal_', + 'fan_mode': 'fan_in', + 'init_nonlinearity': 'relu', + 'init_div_is_residual': True, + 'emb_init_std': None, + 'emb_init_uniform_lim': None, + 'init_std': None, + 'init_gain': 0.0 +} + + +class MPTConfig(PretrainedConfig): + model_type = 'mpt' + attribute_map = { + 'num_attention_heads': 'n_heads', + 'hidden_size': 'd_model', + 'num_hidden_layers': 'n_layers', + } + + # pylint: disable=dangerous-default-value + def __init__(self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + expansion_ratio: int = 4, + max_seq_len: int = 2048, + vocab_size: int = 50368, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + learned_pos_emb: bool = True, + attn_config: Dict = attn_config_defaults, + ffn_config: Dict = ffn_config_defaults, + init_device: str = 'cpu', + logit_scale: Optional[Union[float, str]] = None, + no_bias: bool = False, + embedding_fraction: float = 1.0, + norm_type: str = 'low_precision_layernorm', + use_cache: bool = False, + init_config: Dict = init_config_defaults, + fc_type: str = 'torch', + verbose: Optional[int] = None, + **kwargs: Any): + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.expansion_ratio = expansion_ratio + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.learned_pos_emb = learned_pos_emb + self.attn_config = attn_config + self.ffn_config = ffn_config + self.init_device = init_device + self.logit_scale = logit_scale + self.no_bias = no_bias + self.embedding_fraction = embedding_fraction + self.norm_type = norm_type + self.use_cache = use_cache + self.init_config = init_config + self.fc_type = fc_type + if verbose is not None: + warnings.warn(DeprecationWarning( + 'verbose argument for MPTConfig is now ignored and ' + 'will be removed. 
Use python_log_level instead.'), + stacklevel=2) + if 'name' in kwargs: + del kwargs['name'] + if 'loss_fn' in kwargs: + del kwargs['loss_fn'] + if self.attn_config.get('alibi', False): + self.learned_pos_emb = False + warnings.warn( + f'alibi is turned on, setting `learned_pos_emb` ' + f'to {self.learned_pos_emb}`', + stacklevel=2) + super().__init__(**kwargs) + self._validate_config() + + def _set_config_defaults( + self, config: Dict[str, Any], + config_defaults: Dict[str, Any]) -> Dict[str, Any]: + for (k, v) in config_defaults.items(): + if k not in config: + config[k] = v + return config + + def _validate_config(self) -> None: + self.attn_config = self._set_config_defaults(self.attn_config, + attn_config_defaults) + self.ffn_config = self._set_config_defaults(self.ffn_config, + ffn_config_defaults) + self.init_config = self._set_config_defaults(self.init_config, + init_config_defaults) + if self.d_model % self.n_heads != 0: + raise ValueError('d_model must be divisible by n_heads') + if any( + prob < 0 or prob > 1 for prob in + [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop + ]): + raise ValueError( + "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are " + "probabilities and must be between 0 and 1") + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: + raise ValueError( + f"Unknown attn_impl={self.attn_config['attn_impl']}") + if self.attn_config['prefix_lm'] and self.attn_config[ + 'attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError( + 'prefix_lm only implemented with torch and triton attention.') + if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [ + 'torch', 'triton' + ]: + raise NotImplementedError( + 'alibi only implemented with torch and triton attention.') + if self.attn_config['attn_uses_sequence_id'] and self.attn_config[ + 'attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError( + 'attn_uses_sequence_id only implemented with torch ' + 'and triton attention.') + if self.embedding_fraction > 1 or self.embedding_fraction <= 0: + raise ValueError( + 'model.embedding_fraction must be between 0 (exclusive) ' + 'and 1 (inclusive)!') + if isinstance(self.logit_scale, + str) and self.logit_scale != 'inv_sqrt_d_model': + raise ValueError( + f"self.logit_scale={self.logit_scale!r} is not recognized as " + "an option; use numeric value or 'inv_sqrt_d_model'.") + if self.init_config.get('name', None) is None: + raise ValueError( + f"self.init_config={self.init_config!r} 'name' needs to be set." + ) + if not self.learned_pos_emb and (not self.attn_config['alibi']): + warnings.warn( + 'Positional information not being provided to the model.', + stacklevel=2) + if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': + try: + # pylint: disable=import-outside-toplevel + import transformer_engine.pytorch as te + del te + except Exception as exc: + raise ImportError( + 'TransformerEngine import fail. `fc_type: te` requires ' + 'TransformerEngine be installed. 
' + 'The required version of transformer_engine also requires ' + 'FlashAttention v1.0.6 is installed:\n' + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' + ) from exc + if self.ffn_config['ffn_type'] == 'mptmlp': + self.ffn_config['fc_type'] = self.fc_type + elif self.ffn_config['ffn_type'] == 'te_ln_mlp': + self.ffn_config['bias'] = not self.no_bias diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py new file mode 100644 index 0000000..fdf4fa2 --- /dev/null +++ b/vllm/transformers_utils/configs/nemotron.py @@ -0,0 +1,204 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Nemotron model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NemotronConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`NemotronModel`]. It is used to instantiate an Nemotron model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Nemotron-8B. + + Configuration objects inherit from [`PretrainedConfig`] and can be + used to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Nemotron model. Defines the number of + different tokens that can be represented by the + `inputs_ids` passed when calling [`NemotronModel`] + hidden_size (`int`, *optional*, defaults to 6144): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the + Transformer decoder. + head_dim (`int`, *optional*): + Projection weights dimension in multi-head attention. Set to + hidden_size // num_attention_heads if None + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use + Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention + (MQA) otherwise GQA is used. When converting a multi-head + checkpoint to a GQA checkpoint, each group key and value + head should be constructed by meanpooling all the original + heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it + is not specified, will default to `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`): + The non-linear activation function (function or string) in the + decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used + with. + initializer_range (`float`, *optional*, defaults to 0.0134): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 3): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output + projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj and down_proj layers in the MLP + layers. + + ```python + >>> from transformers import NemotronModel, NemotronConfig + >>> # Initializing a Nemotron nemotron-15b style configuration + >>> configuration = NemotronConfig() + >>> # Initializing a model from the nemotron-15b style configuration + >>> model = NemotronModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nemotron" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=6144, + intermediate_size=24576, + num_hidden_layers=32, + num_attention_heads=48, + head_dim=None, + num_key_value_heads=None, + hidden_act="relu2", + max_position_embeddings=4096, + initializer_range=0.0134, + norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=2, + eos_token_id=3, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.5, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + head_dim = head_dim or kwargs.get("kv_channels") + self.head_dim = head_dim if head_dim is not None else ( + hidden_size // num_attention_heads) + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # for 
backward compatibility + partial_rotary_factor = kwargs.get("rope_percent") or kwargs.get( + "rope_percentage") or partial_rotary_factor + self.partial_rotary_factor = partial_rotary_factor + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len( + self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, " + f"`type` and `factor`, got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", "dynamic" + ]: + raise ValueError( + "`rope_scaling`'s type field must be one of ['linear', " + f"'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance( + rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError( + "`rope_scaling`'s factor field must be a float > 1, got " + f"{rope_scaling_factor}") diff --git a/vllm/transformers_utils/configs/nvlm_d.py b/vllm/transformers_utils/configs/nvlm_d.py new file mode 100644 index 0000000..300f6e2 --- /dev/null +++ b/vllm/transformers_utils/configs/nvlm_d.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/nvidia/NVLM-D-72B/blob/main/configuration_nvlm_d.py +# -------------------------------------------------------- +# NVLM-D +# Copyright (c) 2024 NVIDIA +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- +from .internvl import InternVLChatConfig + + +class NVLM_D_Config(InternVLChatConfig): + model_type = 'NVLM_D' diff --git a/vllm/transformers_utils/configs/olmo2.py b/vllm/transformers_utils/configs/olmo2.py new file mode 100644 index 0000000..c6e4463 --- /dev/null +++ b/vllm/transformers_utils/configs/olmo2.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 + +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/configuration_olmo2.py +"""OLMo 2 configuration.""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Olmo2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the Olmo2 model. 
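As a quick illustration of how `NemotronConfig`'s derived defaults resolve, here is a minimal sketch (assuming a vLLM checkout that includes the `nemotron.py` file added in this diff):

```python
# Minimal sketch: how NemotronConfig's derived defaults resolve.
# Assumes the vllm.transformers_utils.configs.nemotron module added in this diff.
from vllm.transformers_utils.configs.nemotron import NemotronConfig

cfg = NemotronConfig()  # Nemotron-8B-like defaults
assert cfg.head_dim == cfg.hidden_size // cfg.num_attention_heads  # 6144 // 48 == 128
assert cfg.num_key_value_heads == cfg.num_attention_heads          # plain MHA unless overridden
assert cfg.partial_rotary_factor == 0.5                            # half the head dim gets RoPE
```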
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Olmo2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. 
+ + ```python + >>> from transformers import Olmo2Model, Olmo2Config + + >>> # Initializing a Olmo2 7B style configuration + >>> configuration = Olmo2Config() + + >>> # Initializing a model from the Olmo2 7B style configuration + >>> model = Olmo2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "olmo2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
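The `_rope_scaling_validation` helpers in these config classes (Nemotron above, OLMo 2 here, Solar further down) all accept the same two-field dictionary. A minimal sketch of what passes and what is rejected, using the `Olmo2Config` from this diff:

```python
# Minimal sketch: the rope_scaling shape these configs validate.
from vllm.transformers_utils.configs.olmo2 import Olmo2Config

Olmo2Config(rope_scaling={"type": "linear", "factor": 2.0})    # accepted: known type, factor > 1

try:
    Olmo2Config(rope_scaling={"type": "yarn", "factor": 2.0})  # rejected: unsupported type
except ValueError as err:
    print(err)  # type field must be one of ['linear', 'dynamic']
```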
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/vllm/transformers_utils/configs/skyworkr1v.py b/vllm/transformers_utils/configs/skyworkr1v.py new file mode 100644 index 0000000..ef5f9ba --- /dev/null +++ b/vllm/transformers_utils/configs/skyworkr1v.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py +# -------------------------------------------------------- +# SkyworkR1V +# Copyright (c) 2025 Skywork +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from transformers.configuration_utils import PretrainedConfig + + +class SkyworkR1VChatConfig(PretrainedConfig): + model_type = 'internvl_chat' + is_composition = True + + def __init__(self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + select_layer=-1, + force_image_size=None, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + ps_version='v1', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + + if llm_config is None: + llm_config = {} + + self.vision_config = PretrainedConfig(**vision_config) + self.text_config = PretrainedConfig(**llm_config) + + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.select_layer = select_layer + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch diff --git a/vllm/transformers_utils/configs/solar.py b/vllm/transformers_utils/configs/solar.py new file mode 100644 index 0000000..0d5db89 --- /dev/null +++ b/vllm/transformers_utils/configs/solar.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Solar model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class SolarConfig(PretrainedConfig): + r""" + This is the configuration class to store + the configuration of a [`SolarModel`]. + It is used to instantiate an LLaMA model + according to the specified arguments, + defining the model architecture. + Instantiating a configuration with the + defaults will yield a similar + configuration to that of the LLaMA-7B. + Configuration objects inherit from [`PretrainedConfig`] + and can be used to control the model outputs. + Read the documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. + Defines the number of different tokens + that can be represented by the `inputs_ids` + passed when calling [`SolarModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer + in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that + should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, + the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) + otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed + by meanpooling all the original heads within that group. + For more details checkout [this paper] + (https://arxiv.org/pdf/2305.13245.pdf). + If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) + in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Solar 1 supports up to 2048 tokens, + Solar 2 up to 4096, CodeSolar up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of + the truncated_normal_initializer for initializing + all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return + the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. 
+ pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank + used during pretraining. + Please refer to [this + document](https://huggingface.co/docs/ + transformers/main/ + perf_train_gpu_many#tensor-parallelism) + to understand more about it. This value is + necessary to ensure exact reproducibility + of the pretraining results. + Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for + the RoPE embeddings. + Currently supports two scaling + strategies: linear and dynamic. + Their scaling factor must be a float greater than 1. + The expected format is + `{"type": strategy name, "factor": scaling factor}`. + When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/ + dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking + API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value + and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj + layers in the MLP layers. + sliding_window (`int`, *optional*, defaults to 2047): + Sliding window attention window size. If not specified, + will default to `2047`. 
+ ```python + >>> from transformers import SolarModel, SolarConfig + >>> # Initializing a Solar-pro style configuration + >>> configuration = SolarConfig() + >>> # Initializing a model from the Solar-pro style configuration + >>> model = SolarModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "solar" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=2047, + bskcn_1=None, + bskcn_2=None, + bskcn_3=None, + bskcn_4=None, + bskcn_tv=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.sliding_window = sliding_window + self.bskcn_1 = bskcn_1 if bskcn_1 is not None else [12, 20, 32, 44] + self.bskcn_2 = bskcn_2 if bskcn_2 is not None else [20, 32] + self.bskcn_3 = bskcn_3 if bskcn_3 is not None else [16, 24, 36, 48] + self.bskcn_4 = bskcn_4 if bskcn_4 is not None else [28, 40] + self.bskcn_tv = bskcn_tv if bskcn_tv is not None else [0.9, 0.8] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if (not isinstance(self.rope_scaling, dict) + or len(self.rope_scaling) != 2): + raise ValueError( + "`rope_scaling` must be a dictionary with two fields," + " `type` and `factor`, " + f"got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", + "dynamic", + ]: + raise ValueError(f"`rope_scaling`'s type field must be one of " + f"['linear', 'dynamic'], got {rope_scaling_type}") + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1," + f" got {rope_scaling_factor}") diff --git a/vllm/transformers_utils/configs/telechat2.py b/vllm/transformers_utils/configs/telechat2.py new file mode 100644 index 0000000..5da6c5b --- /dev/null +++ b/vllm/transformers_utils/configs/telechat2.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://www.modelscope.cn/models/TeleAI/TeleChat2-3B/resolve/master/configuration_telechat2.py +""" Telechat configuration compatible with LlamaConfig. """ + +from transformers.configuration_utils import PretrainedConfig + + +class Telechat2Config(PretrainedConfig): + + model_type = "telechat" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "intermediate_size": "ffn_hidden_size", + "rms_norm_eps": "layer_norm_epsilon" + } + + def __init__( + self, + vocab_size=160256, + hidden_size=4096, + n_layer=30, + n_head=32, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + ffn_hidden_size=12288, + training_seqlen=8192, + logn=True, + embed_layernorm=False, + hidden_act="silu", + **kwargs, + ): + self.vocab_size = vocab_size + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.logn = logn + self.training_seqlen = training_seqlen + self.embed_layernorm = embed_layernorm + self.num_key_value_heads = kwargs.pop("num_key_value_heads", None) + self.ffn_hidden_size = ffn_hidden_size + self.hidden_act = hidden_act + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py new file mode 100644 index 0000000..6b2765d --- /dev/null +++ b/vllm/transformers_utils/configs/ultravox.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py +from typing import Any, Dict, Optional + +import transformers + + +class UltravoxConfig(transformers.PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + 
[`UltravoxForConditionalGeneration`]. It is used to instantiate an + Ultravox model according to the specified arguments, defining the model + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + Args: + audio_config (`Union[AutoConfig, dict]`, *optional*): + Custom audio config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` + or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + audio_token_index (`int`, *optional*, defaults to 32000): + The audio token index to encode the audio prompt. + stack_factor (`int`, *optional*, defaults to 8): + Audio downsampling factor for the multimodal projector. + norm_init (`float`, *optional*, defaults to 0.4): + The initialization value for the layer normalization. + projector_act (`str`, *optional*, defaults to `"swiglu"`): + The activation function used by the multimodal projector. + text_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the text model. + audio_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the audio model. + projector_ln_mid (`bool`, *optional*, defaults to `False`): + Whether to apply layer normalization at the middle of the + projector or at the end. Versions v0.4.1 and below + use `False`, but v0.5 and above use `True`. + """ + + model_type = "ultravox" + is_composition = False + + def __init__( + self, + audio_config: Optional[Dict[str, Any]] = None, + text_config: Optional[Dict[str, Any]] = None, + audio_model_id: Optional[str] = None, + text_model_id: Optional[str] = None, + ignore_index: int = -100, + audio_token_index: int = 32000, + hidden_size: int = 4096, + stack_factor: int = 8, + norm_init: float = 0.4, + projector_act: str = "swiglu", + text_model_lora_config: Optional[Dict[str, Any]] = None, + audio_model_lora_config: Optional[Dict[str, Any]] = None, + projector_ln_mid: bool = False, + **kwargs, + ): + self.ignore_index = ignore_index + + self.audio_model_id = audio_model_id + self.text_model_id = text_model_id + self.audio_token_index = audio_token_index + + self.hidden_size = hidden_size + self.stack_factor = stack_factor + self.norm_init = norm_init + self.projector_act = projector_act + self.projector_ln_mid = projector_ln_mid + + if text_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.text_config = get_config(text_model_id, + trust_remote_code=False) + else: + text_config = text_config or {} + self.text_config = transformers.CONFIG_MAPPING[text_config.get( + "model_type", "llama")](**text_config) + + if audio_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(audio_model_id, + trust_remote_code=False) + else: + audio_config = audio_config or {} + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( + "model_type", "whisper")](**audio_config) + + self.text_model_lora_config = text_model_lora_config or {} + self.audio_model_lora_config = audio_model_lora_config or {} + + self.vocab_size = self.text_config.vocab_size + + self.initializer_range = self.text_config.initializer_range + + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/detokenizer.py 
b/vllm/transformers_utils/detokenizer.py new file mode 100644 index 0000000..9d1d4bb --- /dev/null +++ b/vllm/transformers_utils/detokenizer.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional + +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, + Sequence, SequenceGroup) + +from .detokenizer_utils import (convert_prompt_ids_to_tokens, + detokenize_incrementally) +from .tokenizer import AnyTokenizer +from .tokenizer_group import BaseTokenizerGroup + + +class Detokenizer: + """Provides methods to decode the output of a model into text.""" + + def __init__(self, tokenizer_group: BaseTokenizerGroup): + self.tokenizer_group = tokenizer_group + + def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: + """Returns the HF tokenizer to use for a given sequence.""" + return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, + prompt_logprobs: List[Optional[Dict[ + int, Logprob]]], + position_offset: int) -> None: + """Decodes the logprobs for the prompt of a sequence group. + + Args: + seq_group: The sequence group to decode. + prompt_logprobs: The logprobs to decode. + position_offset: Offset of the first index of the logprobs + relative to the start of the sequence (for chunked prefill). + + Returns: + The prompt logprobs with the decoded tokens. + """ + prms = seq_group.sampling_params + assert prms is not None + + # We can pick any sequence for the prompt. + seq = seq_group.get_seqs()[0] + # Only prompt, without the generated token. + all_token_ids = seq.get_token_ids() + prompt_token_ids = all_token_ids[:-1] + tokenizer = self.get_tokenizer_for_seq(seq) + prefix_offset = 0 + read_offset = 0 + next_iter_prefix_offset = 0 + next_iter_read_offset = 0 + next_iter_tokens: List[str] = [] + prev_tokens = None + + for token_position_in_logprob, prompt_logprobs_for_token in enumerate( + prompt_logprobs): + + # Absolute token position equals the index in the logprobs + # list plus the offset of the entire logprobs list relative + # to the start of the sequence. + token_position = token_position_in_logprob + position_offset + if not prompt_logprobs_for_token: + continue + for token_id, sample_logprob in prompt_logprobs_for_token.items(): + if (sample_logprob.decoded_token is None + and token_id != VLLM_INVALID_TOKEN_ID): + prompt_token_ids_with_token = ( + prompt_token_ids[:token_position] + [token_id]) + (new_tokens, new_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=prompt_token_ids_with_token, + prev_tokens=prev_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. + spaces_between_special_tokens, + ) + + sample_logprob.decoded_token = new_text + + # Use the offsets & prev tokens corresponding to + # real tokens to ensure detokenization is consistent + # actual with prompt. + if token_id == all_token_ids[token_position]: + next_iter_prefix_offset = new_prefix_offset + next_iter_read_offset = new_read_offset + next_iter_tokens = new_tokens + + # Advance to the next token position. 
+ prefix_offset = next_iter_prefix_offset + read_offset = next_iter_read_offset + if prev_tokens is None: + prev_tokens = next_iter_tokens.copy() + else: + prev_tokens.extend(next_iter_tokens) + + def decode_sequence_inplace(self, seq: Sequence, + prms: SamplingParams) -> int: + """Decodes the new token for a sequence. In-place operation. + + Args: + seq: The sequence to decode. + prms: The sampling parameters used to generate the sequence. + + Returns: + The number of characters added to the output text. + """ + all_input_ids = seq.get_token_ids() + token_id_generated_this_iteration = all_input_ids[-1] + tokenizer = self.get_tokenizer_for_seq(seq) + + # Convert prompt token IDs to tokens if necessary. + # Do it here so that we don't have to repeat this + # computation for each logprob. + if seq.tokens is None: + (seq.tokens, seq.prefix_offset, + seq.read_offset) = convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=all_input_ids[:-1], + skip_special_tokens=prms.skip_special_tokens, + ) + + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=all_input_ids, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + + # Decode logprobs + logprobs = seq.output_logprobs[-1] + if logprobs: + previous_tokens = all_input_ids[:-1] + for token_id, sample_logprob in logprobs.items(): + # If the token was generated this iteration, + # use the provided text. + if token_id == token_id_generated_this_iteration: + sample_logprob.decoded_token = new_decoded_token_text + continue + + if (sample_logprob.decoded_token is None + and token_id != VLLM_INVALID_TOKEN_ID): + all_input_ids_with_logprob = previous_tokens + [token_id] + (_, new_text, _, _) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=all_input_ids_with_logprob, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. + spaces_between_special_tokens, + ) + sample_logprob.decoded_token = new_text + + seq.tokens.extend(new_tokens) + seq.prefix_offset = prefix_offset + seq.read_offset = read_offset + seq.output_text += new_decoded_token_text + + return len(new_decoded_token_text) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py new file mode 100644 index 0000000..a1fa277 --- /dev/null +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple + +from .tokenizer import AnyTokenizer + + +def _replace_none_with_empty(tokens: List[Optional[str]]): + for i, token in enumerate(tokens): + if token is None: + tokens[i] = "" + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: AnyTokenizer, + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. 
+ sub_texts: List[str] = [] + current_sub_text: List[str] = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# 5 is an arbitrary value that should work for all +# tokenizers (bigger = more conservative). +INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + + +def convert_prompt_ids_to_tokens( + tokenizer: AnyTokenizer, + prompt_ids: List[int], + skip_special_tokens: bool = False, +) -> Tuple[List[str], int, int]: + """Converts the prompt ids to tokens and returns the tokens and offsets + for incremental detokenization. + + Note that not all tokens are converted to strings. Only the tokens that + are necessary for incremental detokenization are converted to strings. + """ + # We do not need to convert the whole prompt to tokens. + # Offset a little more in case we have special tokens. + new_tokens = tokenizer.convert_ids_to_tokens( + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], + skip_special_tokens=skip_special_tokens) + read_offset = len(new_tokens) + prefix_offset = max( + read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + # This is required to guard against out-of-vocab prompt token ids + _replace_none_with_empty(new_tokens) # type: ignore[arg-type] + return new_tokens, prefix_offset, read_offset + + +def convert_ids_list_to_tokens( + tokenizer: AnyTokenizer, + token_ids: List[int], +) -> List[str]: + """Detokenize the input ids individually. + + Args: + tokenizer: tokenizer used by model under test + token_ids: convert these tokens (Python list form) + + Returns: + Python list of token string representations + + """ + token_str_lst = tokenizer.convert_ids_to_tokens(token_ids) + _replace_none_with_empty(token_str_lst) # type: ignore + return token_str_lst + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: AnyTokenizer, + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int, + read_offset: int, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + """Detokenizes the input ids incrementally and returns the new tokens + and the new text. + + If `prev_tokens` is None, this function will convert the input ids to + tokens and return the tokens and the new text. Otherwise, it will return the + new tokens and the new text. + + This function will also return the new prefix offset and the new read + offset to be used in the next iteration. + + The offsets are necessary to defeat cleanup algorithms in the decode which + decide to add a space or not depending on the surrounding ids. + + Args: + tokenizer: The tokenizer to use. + all_input_ids: The input ids. The last id is the new token id. + prev_tokens: The previous tokens. If None, this function will convert + the input ids to tokens and return the tokens and the new text. 
+ prefix_offset: The prefix offset. + read_offset: The read offset. + skip_special_tokens: Whether to skip special tokens. + spaces_between_special_tokens: Whether to add spaces between special + tokens. + """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None + + # If the new token id is out of bounds, return an empty string. + if 0 <= new_token_id < len(tokenizer): + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] + else: + new_tokens = [""] + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. + if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py new file mode 100644 index 0000000..1d09b99 --- /dev/null +++ b/vllm/transformers_utils/processor.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Union, cast + +from transformers.processing_utils import ProcessorMixin +from typing_extensions import TypeVar + +if TYPE_CHECKING: + from vllm.config import ModelConfig + +_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) + + +class HashableDict(dict): + """ + A dictionary that can be hashed by lru_cache. + """ + + # NOTE: pythonic dict is not hashable, + # we override on it directly for simplicity + def __hash__(self) -> int: # type: ignore[override] + return hash(frozenset(self.items())) + + +class HashableList(list): + """ + A list that can be hashed by lru_cache. 
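A minimal sketch of driving `detokenize_incrementally` above in a loop, assuming any standard HuggingFace tokenizer (`gpt2` here is just a stand-in) and the `detokenizer_utils` module added in this diff:

```python
# Minimal sketch: incremental detokenization, roughly as vLLM's Detokenizer drives it.
from transformers import AutoTokenizer
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_ids = tokenizer.encode("Hello")          # "prompt" ids
prev_tokens, prefix_offset, read_offset = None, 0, 0
generated = ""

for new_id in tokenizer.encode(" world, how are you?"):   # pretend these were sampled
    all_ids.append(new_id)
    tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer, all_ids, prev_tokens, prefix_offset, read_offset)
    prev_tokens = tokens if prev_tokens is None else prev_tokens + tokens
    generated += new_text

print(generated)  # " world, how are you?" -- only the newly appended text is accumulated
```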
+ """ + + def __hash__(self) -> int: # type: ignore[override] + return hash(tuple(self)) + + +def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs): + base_kwargs = model_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + + # NOTE: Pythonic dict is not hashable and will raise unhashable type + # error when calling `cached_get_processor`, therefore we need to + # wrap it to a hashable dict. + for key, value in merged_kwargs.items(): + if isinstance(value, dict): + merged_kwargs[key] = HashableDict(value) + if isinstance(value, list): + merged_kwargs[key] = HashableList(value) + return merged_kwargs + + +def get_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + **kwargs: Any, +) -> _P: + """Load a processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor + + processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or + isinstance(processor_cls, tuple) else processor_cls) + + try: + processor = processor_factory.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the processor. If the processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + if not isinstance(processor, processor_cls): + raise TypeError("Invalid type of HuggingFace processor. " + f"Expected type: {processor_cls}, but " + f"found type: {type(processor)}") + + return processor + + +cached_get_processor = lru_cache(get_processor) + + +def cached_processor_from_config( + model_config: "ModelConfig", + processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + **kwargs: Any, +) -> _P: + return cached_get_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + processor_cls=processor_cls, # type: ignore[arg-type] + **_merge_mm_kwargs(model_config, **kwargs), + ) + + +def get_image_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load an image processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor + from transformers.image_processing_utils import BaseImageProcessor + + try: + processor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. 
If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(BaseImageProcessor, processor) + + +cached_get_image_processor = lru_cache(get_image_processor) + + +def cached_image_processor_from_config( + model_config: "ModelConfig", + **kwargs: Any, +): + return cached_get_image_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **_merge_mm_kwargs(model_config, **kwargs), + ) + + +def get_video_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load a video processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers.image_processing_utils import BaseImageProcessor + + processor = get_processor( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cast(BaseImageProcessor, processor.video_processor) + + +cached_get_video_processor = lru_cache(get_video_processor) + + +def cached_video_processor_from_config( + model_config: "ModelConfig", + **kwargs: Any, +): + return cached_get_video_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **_merge_mm_kwargs(model_config, **kwargs), + ) diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py new file mode 100644 index 0000000..4696f0c --- /dev/null +++ b/vllm/transformers_utils/processors/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.transformers_utils.processors.deepseek_vl2 import ( + DeepseekVLV2Processor) + +__all__ = ["DeepseekVLV2Processor"] diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py new file mode 100644 index 0000000..316281f --- /dev/null +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 + +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/ff23960c5cf9e6874b44be38af930cfb0ccbb620/deepseek_vl2/models/processing_deepseek_vl_v2.py +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
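The `HashableDict`/`HashableList` wrappers above exist because `lru_cache` builds a hash key from every argument, so a plain `dict` inside `mm_processor_kwargs` would make that key unhashable. A small self-contained sketch:

```python
# Minimal sketch: why _merge_mm_kwargs wraps nested dicts before hitting lru_cache.
from functools import lru_cache

from vllm.transformers_utils.processor import HashableDict

@lru_cache
def cached_lookup(**kwargs):
    return kwargs

try:
    cached_lookup(size={"shortest_edge": 336})            # plain dict -> unhashable cache key
except TypeError as err:
    print(err)                                            # unhashable type: 'dict'

cached_lookup(size=HashableDict({"shortest_edge": 336}))  # hashable, so the result is cached
```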
+ +import math +from typing import List, Tuple + +import torch +import torchvision.transforms as T +from PIL import Image, ImageOps +from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast +from transformers.processing_utils import ProcessorMixin + + +class ImageTransform: + + def __init__(self, + mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True): + self.mean = mean + self.std = std + self.normalize = normalize + + transform_pipelines = [T.ToTensor()] + + if normalize: + transform_pipelines.append(T.Normalize(mean, std)) + + self.transform = T.Compose(transform_pipelines) + + def __call__(self, pil_img: Image.Image): + x = self.transform(pil_img) + return x + + +class DeepseekVLV2Processor(ProcessorMixin): + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + attributes = ["tokenizer"] + + def __init__( + self, + tokenizer: LlamaTokenizerFast, + candidate_resolutions: Tuple[Tuple[int, int]], + patch_size: int, + downsample_ratio: int, + image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + image_token: str = "", + pad_token: str = "<|▁pad▁|>", + add_special_token: bool = False, + sft_format: str = "deepseek", + mask_prompt: bool = True, + ignore_id: int = -100, + **kwargs, + ): + + self.candidate_resolutions = candidate_resolutions + self.image_size = candidate_resolutions[0][0] + self.patch_size = patch_size + self.image_mean = image_mean + self.image_std = image_std + self.normalize = normalize + self.downsample_ratio = downsample_ratio + + self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize) + self.tokenizer = tokenizer + self.tokenizer.padding_side = 'left' # must set this,padding side with make a difference in batch inference + + # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id' + if tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({'pad_token': pad_token}) + + # add image token + image_token_id = self.tokenizer.vocab.get(image_token) + if image_token_id is None: + special_tokens = [image_token] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + self.image_token_id = self.tokenizer.vocab.get(image_token) + + # add five special tokens for grounding-related tasks + # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|> + special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>'] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + + # add special tokens for SFT data + special_tokens = ["<|User|>", "<|Assistant|>"] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + + self.image_token = image_token + self.pad_token = pad_token + self.add_special_token = add_special_token + self.sft_format = sft_format + self.mask_prompt = mask_prompt + self.ignore_id = ignore_id + + super().__init__( + tokenizer, + **kwargs, + ) + + def select_best_resolution(self, image_size): + # used for cropping + original_width, original_height = image_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for width, height in self.candidate_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, 
downscaled_height = int( + original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, + original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def pad_id(self): + return self.tokenizer.pad_token_id + + def encode(self, text: str, bos: bool = True, eos: bool = False): + t = self.tokenizer.encode(text, add_special_tokens=False) + + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + + return t + + def decode(self, t: List[int], **kwargs) -> str: + return self.tokenizer.decode(t, **kwargs) + + def process_one( + self, + prompt: str, + images: List[Image.Image], + inference_mode: bool = True, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + inference_mode (bool): if True, then remove the last eos token; + system_prompt (str): the system prompt; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - pixel_values (torch.FloatTensor): [n_patches, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert (prompt is not None and images is not None + ), "prompt and images must be used at the same time." 
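`select_best_resolution` above picks the candidate grid that preserves the most effective pixels and wastes the least area. A worked trace of that rule with made-up candidate resolutions (the real list comes from the processor's config):

```python
# Minimal sketch of the selection rule in select_best_resolution.
# The candidate list below is illustrative; the processor reads it from its config.
candidates = [(384, 384), (384, 768), (768, 384), (768, 768)]
image_size = (1024, 512)  # width, height of the input image

best_fit, max_eff, min_wasted = None, 0, float("inf")
for width, height in candidates:
    scale = min(width / image_size[0], height / image_size[1])
    down_w, down_h = int(image_size[0] * scale), int(image_size[1] * scale)
    effective = min(down_w * down_h, image_size[0] * image_size[1])
    wasted = width * height - effective
    if effective > max_eff or (effective == max_eff and wasted < min_wasted):
        max_eff, min_wasted, best_fit = effective, wasted, (width, height)

print(best_fit)  # (768, 384): keeps the most pixels with no wasted area
```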
+ + sft_format = prompt + tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens = self.tokenize_with_images( + sft_format, images, bos=True, eos=True, cropping=len(images) <= 2) + masked_tokenized_str = [] + for token_index in tokenized_str: + if token_index != self.image_token_id: + masked_tokenized_str.append(token_index) + else: + masked_tokenized_str.append(self.ignore_id) + + assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \ + (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " + f"imags_seq_mask's length {len(images_seq_mask)}, are not equal") + + input_ids = torch.LongTensor(tokenized_str) + target_ids = torch.LongTensor(masked_tokenized_str) + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) + + # set input_ids < 0 | input_ids == self.image_token_id as ignore_id + target_ids[(input_ids < 0) | + (input_ids == self.image_token_id)] = self.ignore_id + input_ids[input_ids < 0] = self.pad_id + + if inference_mode: + # Remove the ending eos token + assert input_ids[-1] == self.eos_id + input_ids = input_ids[:-1] + target_ids = target_ids[:-1] + images_seq_mask = images_seq_mask[:-1] + + if len(images_list) == 0: + pixel_values = torch.zeros((1, 3, self.image_size, self.image_size)) + images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) + else: + pixel_values = torch.stack(images_list, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + + input_ids = input_ids.unsqueeze(0) + + prepare = BatchFeature( + data=dict( + input_ids=input_ids, + pixel_values=pixel_values, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + num_image_tokens=num_image_tokens, + ), + tensor_type="pt", + ) + return prepare + + def __call__( + self, + *, + prompt: str, + images: List[Image.Image], + inference_mode: bool = True, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + images (List[ImageType]): the list of images; + inference_mode (bool): if True, then remove the last eos token; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + prepare = self.process_one( + prompt=prompt, + images=images, + inference_mode=inference_mode, + ) + + return prepare + + def tokenize_with_images( + self, + conversation: str, + images: List[Image.Image], + bos: bool = True, + eos: bool = True, + cropping: bool = True, + ): + """Tokenize text with tags.""" + assert conversation.count(self.image_token) == len(images) + text_splits = conversation.split(self.image_token) + images_list, images_seq_mask, images_spatial_crop = [], [], [] + num_image_tokens = [] + tokenized_str = [] + for text_sep, image in zip(text_splits, images): + """encode text_sep""" + tokenized_sep = self.encode(text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """select best resolution for anyres""" + if cropping: + best_width, best_height = self.select_best_resolution(image.size) + else: + best_width, best_height = self.image_size, self.image_size + + """process the global view""" + global_view = ImageOps.pad(image, (self.image_size, self.image_size), + color=tuple(int(x * 255) for x in self.image_transform.mean)) + 
images_list.append(self.image_transform(global_view)) + + """process the local views""" + local_view = ImageOps.pad(image, (best_width, best_height), + color=tuple(int(x * 255) for x in self.image_transform.mean)) + for i in range(0, best_height, self.image_size): + for j in range(0, best_width, self.image_size): + images_list.append( + self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size)))) + + """record height / width crop num""" + num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size + images_spatial_crop.append([num_width_tiles, num_height_tiles]) + + """add image tokens""" + h = w = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio) + # global views tokens h * (w + 1), 1 is for line separator + tokenized_image = [self.image_token_id] * h * (w + 1) + # add a separator between global and local views + tokenized_image += [self.image_token_id] + # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1) + tokenized_image += [self.image_token_id] * (num_height_tiles * h) * (num_width_tiles * w + 1) + + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + num_image_tokens.append(len(tokenized_image)) + + """process the last text split""" + tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """add the bos and eos tokens""" + if bos: + tokenized_str = [self.bos_id] + tokenized_str + images_seq_mask = [False] + images_seq_mask + if eos: + tokenized_str = tokenized_str + [self.eos_id] + images_seq_mask = images_seq_mask + [False] + + assert len(tokenized_str) == len( + images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + + return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens + + +AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor) diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py new file mode 100644 index 0000000..1c3520b --- /dev/null +++ b/vllm/transformers_utils/s3_utils.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 + +import fnmatch +import os +import shutil +import signal +import tempfile +from pathlib import Path +from typing import Optional + +from vllm.utils import PlaceholderModule + +try: + import boto3 +except ImportError: + boto3 = PlaceholderModule("boto3") # type: ignore[assignment] + + +def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]: + return [ + path for path in paths if any( + fnmatch.fnmatch(path, pattern) for pattern in patterns) + ] + + +def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: + return [ + path for path in paths + if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns) + ] + + +def glob(s3=None, + path: str = "", + allow_pattern: Optional[list[str]] = None) -> list[str]: + """ + List full file names from S3 path and filter by allow pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. 
+ + Returns: + list[str]: List of full S3 paths allowed by the pattern + """ + if s3 is None: + s3 = boto3.client("s3") + if not path.endswith("/"): + path = path + "/" + bucket_name, _, paths = list_files(s3, + path=path, + allow_pattern=allow_pattern) + return [f"s3://{bucket_name}/{path}" for path in paths] + + + def list_files( + s3, + path: str, + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None + ) -> tuple[str, str, list[str]]: + """ + List files from S3 path and filter by pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + Returns: + tuple[str, str, list[str]]: A tuple where: + - The first element is the bucket name + - The second element is a string representing the bucket + and the prefix as a directory-like path + - The third element is a list of files allowed or + disallowed by pattern + """ + parts = path.removeprefix('s3://').split('/') + prefix = '/'.join(parts[1:]) + bucket_name = parts[0] + + objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) + paths = [obj['Key'] for obj in objects.get('Contents', [])] + + paths = _filter_ignore(paths, ["*/"]) + if allow_pattern is not None: + paths = _filter_allow(paths, allow_pattern) + + if ignore_pattern is not None: + paths = _filter_ignore(paths, ignore_pattern) + + return bucket_name, prefix, paths + + + class S3Model: + """ + A class representing an S3 model mirrored into a temporary directory. + + Attributes: + s3: S3 client. + dir: The temporary directory created for the mirror. + + Methods: + pull_files(): Pull model from S3 to the temporary directory. + """ + + def __init__(self) -> None: + self.s3 = boto3.client('s3') + for sig in (signal.SIGINT, signal.SIGTERM): + existing_handler = signal.getsignal(sig) + signal.signal(sig, self._close_by_signal(existing_handler)) + + self.dir = tempfile.mkdtemp() + + def __del__(self): + self._close() + + def _close(self) -> None: + if os.path.exists(self.dir): + shutil.rmtree(self.dir) + + def _close_by_signal(self, existing_handler=None): + + def new_handler(signum, frame): + self._close() + if existing_handler: + existing_handler(signum, frame) + + return new_handler + + def pull_files(self, + s3_model_path: str = "", + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None) -> None: + """ + Pull files from S3 storage into the temporary directory. + + Args: + s3_model_path: The S3 path of the model. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull.
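+
+        Example (illustrative only; the bucket, prefix and patterns below are
+        hypothetical):
+            model = S3Model()
+            model.pull_files("s3://my-bucket/my-model",
+                             allow_pattern=["*.json", "*.safetensors"])
+            # the matched files are now mirrored under model.dir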
+ + """ + if not s3_model_path.endswith("/"): + s3_model_path = s3_model_path + "/" + + bucket_name, base_dir, files = list_files(self.s3, s3_model_path, + allow_pattern, + ignore_pattern) + if len(files) == 0: + return + + for file in files: + destination_file = os.path.join( + self.dir, + file.removeprefix(base_dir).lstrip("/")) + local_dir = Path(destination_file).parent + os.makedirs(local_dir, exist_ok=True) + self.s3.download_file(bucket_name, file, destination_file) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py new file mode 100644 index 0000000..1bfb503 --- /dev/null +++ b/vllm/transformers_utils/tokenizer.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import os +import warnings +from functools import lru_cache +from pathlib import Path +from types import MethodType +from typing import TYPE_CHECKING, Any, Optional, Union + +import huggingface_hub +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_base import (TokenizerBase, + TokenizerRegistry) +from vllm.transformers_utils.tokenizers import MistralTokenizer +from vllm.transformers_utils.utils import check_gguf_file +from vllm.utils import make_async + +if TYPE_CHECKING: + from vllm.config import ModelConfig + +logger = init_logger(__name__) + +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, + TokenizerBase] + + +def decode_tokens( + tokenizer: AnyTokenizer, + token_ids: list[int], + *, + skip_special_tokens: Optional[bool] = None, +) -> str: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.decode(token_ids, ...)`. + + :code:`skip_special_tokens=None` means to use the backend's default + settings. + """ + if skip_special_tokens is not None: + return tokenizer.decode(token_ids, + skip_special_tokens=skip_special_tokens) + + return tokenizer.decode(token_ids) + + +def encode_tokens( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, ...)`. + + :code:`add_special_tokens=None` means to use the backend's default + settings. + """ + if add_special_tokens is not None: + return tokenizer.encode(text, add_special_tokens=add_special_tokens) + + return tokenizer.encode(text) + + +def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: + """Get tokenizer with cached properties. + + This will patch the tokenizer object in place. + + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = ( + tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + tokenizer_vocab = tokenizer.get_vocab() + tokenizer_len = len(tokenizer) + + max_token_id = max(tokenizer_vocab.values()) + # Some tokenizers (e.g., QwenTokenizer) have special tokens that + # are added and included in the implementation of the vocab_size + # property, but not in get_vocab(); if there is an implementation + # of vocab size, we should take the greater value. 
+ if hasattr(tokenizer, "vocab_size"): + with contextlib.suppress(NotImplementedError): + max_token_id = max(max_token_id, tokenizer.vocab_size) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + @property + def max_token_id(self): + return max_token_id + + def get_vocab(self): + return tokenizer_vocab + + def __len__(self): + return tokenizer_len + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + +def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: + """Patch _pad method to accept `padding_side` for older tokenizers.""" + orig_pad = tokenizer._pad + + def _pad( + self: PreTrainedTokenizer, + *args, + padding_side: Optional[str] = None, + **kwargs, + ): + if padding_side is not None and padding_side != self.padding_side: + msg = ("`padding_side` argument is not supported by " + f"{type(tokenizer).__name__} and will be ignored.") + warnings.warn(msg, stacklevel=2) + + return orig_pad(*args, **kwargs) + + tokenizer._pad = MethodType(_pad, tokenizer) + + +def get_tokenizer( + tokenizer_name: Union[str, Path], + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + **kwargs, +) -> AnyTokenizer: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope. + """ + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # avoid circuit import + from vllm.model_executor.model_loader.weight_utils import get_lock + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + # Use file lock to prevent multiple processes from + # downloading the same file at the same time. + with get_lock(tokenizer_name, download_dir): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + # Ignore weights - we only need the tokenizer. 
+ ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + tokenizer_name = tokenizer_path + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if "truncation_side" not in kwargs: + kwargs["truncation_side"] = "left" + + # Separate model folder from file path for GGUF models + is_gguf = check_gguf_file(tokenizer_name) + if is_gguf: + kwargs["gguf_file"] = Path(tokenizer_name).name + tokenizer_name = Path(tokenizer_name).parent + + # if tokenizer is from official mistral org + is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai" + if is_from_mistral_org and tokenizer_mode != "mistral": + warnings.warn( + 'It is strongly recommended to run mistral models with ' + '`--tokenizer-mode "mistral"` to ensure correct ' + 'encoding and decoding.', + FutureWarning, + stacklevel=2) + + tokenizer: AnyTokenizer + if tokenizer_mode == "mistral": + tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), + revision=revision) + elif tokenizer_mode == "custom": + tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name), + *args, + revision=revision, + download_dir=download_dir, + **kwargs) + else: + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, + # suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e)): + err_msg = ("Failed to load the tokenizer. If the tokenizer " + "is a custom tokenizer not yet available in the " + "HuggingFace transformers library, consider " + "setting `trust_remote_code=True` in LLM or using " + "the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324 + if type(tokenizer).__name__ in ("ChatGLMTokenizer", + "ChatGLM4Tokenizer"): + assert isinstance(tokenizer, PreTrainedTokenizer) + patch_padding_side(tokenizer) + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead.") + tokenizer = get_cached_tokenizer(tokenizer) + + return tokenizer + + +cached_get_tokenizer = lru_cache(get_tokenizer) + + +def cached_tokenizer_from_config( + model_config: "ModelConfig", + **kwargs: Any, +): + return cached_get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + tokenizer_revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + **kwargs, + ) + + +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[AnyTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs) + except Exception as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + "No tokenizer found in %s, using base model tokenizer instead. 
" + "(Exception: %s)", lora_request.lora_path, e) + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py new file mode 100644 index 0000000..bb5ddaf --- /dev/null +++ b/vllm/transformers_utils/tokenizer_base.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +if TYPE_CHECKING: + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + + +class TokenizerBase(ABC): + + @property + @abstractmethod + def all_special_tokens_extended(self) -> List[str]: + raise NotImplementedError() + + @property + @abstractmethod + def all_special_tokens(self) -> List[str]: + raise NotImplementedError() + + @property + @abstractmethod + def all_special_ids(self) -> List[int]: + raise NotImplementedError() + + @property + @abstractmethod + def bos_token_id(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def eos_token_id(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def sep_token(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def pad_token(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def is_fast(self) -> bool: + raise NotImplementedError() + + @property + @abstractmethod + def vocab_size(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def max_token_id(self) -> int: + raise NotImplementedError() + + def __len__(self) -> int: + return self.vocab_size + + @abstractmethod + def __call__( + self, + text: Union[str, List[str], List[int]], + text_pair: Optional[str] = None, + add_special_tokens: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + ): + raise NotImplementedError() + + @abstractmethod + def get_vocab(self) -> Dict[str, int]: + raise NotImplementedError() + + @abstractmethod + def get_added_vocab(self) -> Dict[str, int]: + raise NotImplementedError() + + @abstractmethod + def encode_one( + self, + text: str, + truncation: bool = False, + max_length: Optional[int] = None, + ) -> List[int]: + raise NotImplementedError() + + @abstractmethod + def encode(self, + text: str, + add_special_tokens: Optional[bool] = None) -> List[int]: + raise NotImplementedError() + + @abstractmethod + def apply_chat_template(self, + messages: List["ChatCompletionMessageParam"], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs) -> List[int]: + raise NotImplementedError() + + @abstractmethod + def convert_tokens_to_string(self, tokens: List[str]) -> str: + raise NotImplementedError() + + @abstractmethod + def decode(self, + ids: Union[List[int], int], + skip_special_tokens: bool = True) -> str: + raise NotImplementedError() + + @abstractmethod + def convert_ids_to_tokens( + self, + ids: List[int], + skip_special_tokens: bool = True, + ) -> List[str]: + raise NotImplementedError() + + +class TokenizerRegistry: + # Tokenizer name -> (tokenizer module, tokenizer class) + REGISTRY: Dict[str, Tuple[str, str]] = {} + + @staticmethod + def register(name: str, module: str, class_name: str) -> None: + TokenizerRegistry.REGISTRY[name] = (module, class_name) + + @staticmethod + def get_tokenizer( + tokenizer_name: str, + *args, + **kwargs, + ) -> TokenizerBase: + tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name) + if tokenizer_cls is None: + 
raise ValueError(f"Tokenizer {tokenizer_name} not found.") + + tokenizer_module = importlib.import_module(tokenizer_cls[0]) + class_ = getattr(tokenizer_module, tokenizer_cls[1]) + return class_.from_pretrained(*args, **kwargs) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py new file mode 100644 index 0000000..9d22095 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Type + +from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, TokenizerPoolConfig) +from vllm.executor.ray_utils import ray + +from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup +from .tokenizer_group import TokenizerGroup + +if ray: + from .ray_tokenizer_group import RayTokenizerGroupPool +else: + RayTokenizerGroupPool = None # type: ignore + + +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + lora_config: Optional[LoRAConfig]): + init_kwargs = dict(tokenizer_id=model_config.tokenizer, + enable_lora=bool(lora_config), + max_num_seqs=scheduler_config.max_num_seqs, + max_loras=lora_config.max_loras if lora_config else 0, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision, + truncation_side=model_config.truncation_side) + + return get_tokenizer_group(parallel_config.tokenizer_pool_config, + **init_kwargs) + + +def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> BaseTokenizerGroup: + tokenizer_cls: Type[BaseTokenizerGroup] + if tokenizer_pool_config is None: + tokenizer_cls = TokenizerGroup + elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass( + tokenizer_pool_config.pool_type, BaseTokenizerGroup): + tokenizer_cls = tokenizer_pool_config.pool_type + elif tokenizer_pool_config.pool_type == "ray": + if RayTokenizerGroupPool is None: + raise ImportError( + "RayTokenizerGroupPool is not available. 
Please install " + "the ray package to use the Ray tokenizer group pool.") + tokenizer_cls = RayTokenizerGroupPool + else: + raise ValueError( + f"Unknown pool type: {tokenizer_pool_config.pool_type}") + return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) + + +__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py new file mode 100644 index 0000000..c5108a7 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import List, Optional + +from vllm.config import TokenizerPoolConfig +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class BaseTokenizerGroup(ABC): + """A group of tokenizers that can be used for LoRA adapters.""" + + @classmethod + @abstractmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "BaseTokenizerGroup": + pass + + @abstractmethod + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + pass + + @abstractmethod + def get_max_input_len( + self, + lora_request: Optional[LoRARequest] = None, + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + pass + + @abstractmethod + def encode(self, + prompt: str, + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + async def encode_async( + self, + prompt: str, + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + """Get a tokenizer for a LoRA request.""" + pass + + @abstractmethod + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + """Get a tokenizer for a LoRA request.""" + pass + + def check_health(self): + """Raise exception if the tokenizer group is unhealthy.""" + return diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py new file mode 100644 index 0000000..b048b80 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import os +from typing import List, Optional + +try: + from ray.exceptions import ActorDiedError # type: ignore +except ImportError: + # For older versions of Ray + from ray.exceptions import RayActorError as ActorDiedError # type: ignore +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +from vllm.config import TokenizerPoolConfig +from vllm.executor.ray_utils import ray +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import AnyTokenizer + +from .base_tokenizer_group import BaseTokenizerGroup +from .tokenizer_group import TokenizerGroup + +logger = init_logger(__name__) + + +class RayTokenizerGroupPool(BaseTokenizerGroup): + """A Ray-based pool of TokenizerGroups for async tokenization.""" + + # Class to use for workers making up the pool. 
+ _worker_cls = TokenizerGroup + + @classmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "RayTokenizerGroupPool": + if not tokenizer_pool_config: + raise ValueError("tokenizer_pool_config must not be None.") + ray_actor_options = (tokenizer_pool_config.extra_config or { + "num_cpus": 0 + }) + ray_actor_options.setdefault( + "scheduling_strategy", + NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), soft=True)) + + # Carry over the env vars to the actors. + # This is necessary for API keys and such. + ray_actor_options.setdefault("runtime_env", {}) + _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) + + init_kwargs["num_actors"] = tokenizer_pool_config.pool_size + init_kwargs["ray_actor_options"] = ray_actor_options + + return cls(**init_kwargs) + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], num_actors: int, + ray_actor_options: dict, **tokenizer_config): + # Store a local copy of the TokenizerGroup for quick access + # to underlying HF tokenizers. + self._tokenizer_config = { + "tokenizer_id": tokenizer_id, + "enable_lora": enable_lora, + "max_num_seqs": max_num_seqs, + "max_input_length": max_input_length, + **tokenizer_config + } + self._local_tokenizer_group = self._worker_cls( + **self._tokenizer_config, ) + + self._ray_tokenizer_group_cls = ray.remote( + self._worker_cls).options(**ray_actor_options) # type: ignore + self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)] + self._idle_actors: Optional[asyncio.Queue] = None + + # If set, actor is unhealthy. Will reraise on the next + # check_health call. + self._exception: Optional[ActorDiedError] = None + + def _init_actor(self) -> ray.ObjectRef: + return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config) + + @property + def pool_size(self) -> int: + return len(self.tokenizer_actors) + + def ping(self): + return ray.get([ + actor.ping.remote() # type: ignore + for actor in self.tokenizer_actors + ]) + + def _ensure_queue_initialized(self): + if self._idle_actors is None: + self._idle_actors = asyncio.Queue() + for actor in self.tokenizer_actors: + self._idle_actors.put_nowait(actor) + + def _finalize_encode(self, actor: ray.ObjectRef, + original_actor: ray.ObjectRef, actor_is_alive: bool): + assert self._idle_actors is not None + # Cleanup the dead actor. + if not actor_is_alive or original_actor is not actor: + self.tokenizer_actors.remove(original_actor) + if actor_is_alive: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. + self._idle_actors.put_nowait(actor) + # Add back the new actor. + if original_actor is not actor: + self.tokenizer_actors.append(actor) + + def encode(self, + prompt: str, + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + The actor is then put back in the queue for future use. + This is blocking. 
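+
+        Example (illustrative; assumes an already-constructed pool named
+        tokenizer_pool):
+            token_ids = tokenizer_pool.encode(prompt="Hello world")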
+ """ + self.check_health() + self._ensure_queue_initialized() + assert self._idle_actors is not None + + if self._idle_actors.empty(): + raise RuntimeError("No idle actors available.") + actor = self._idle_actors.get_nowait() + actor_is_alive = True + original_actor = actor + try: + ret = ray.get( + actor.encode.remote(prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens)) + except ActorDiedError as e: + # If the actor is dead, we first try to reinitialize it. + logger.warning("%s died with ActorDiedError, reinitializing.", + actor, + exc_info=e) + actor = self._init_actor() + try: + ret = ray.get( + actor.encode.remote(prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens)) + except ActorDiedError as e: + logger.error( + "%s died for second time in a row, marking " + "RayTokenizerGroupPool as unhealthy.", actor) + actor_is_alive = False + if not self._exception: + self._exception = e + self.check_health() + finally: + self._finalize_encode(actor, original_actor, actor_is_alive) + return ret + + async def encode_async( + self, + prompt: str, + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + If there are no idle actors, we wait until one becomes + available. + The actor is then put back in the queue for future use. + This is non-blocking. + """ + self.check_health() + self._ensure_queue_initialized() + assert self._idle_actors is not None + + actor = await self._idle_actors.get() + actor_is_alive = True + original_actor = actor + try: + ret = await actor.encode.remote( + prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) + except ActorDiedError as e: + # If the actor is dead, we first try to reinitialize it. + logger.warning("%s died with ActorDiedError, reinitializing.", + actor, + exc_info=e) + actor = self._init_actor() + try: + ret = await actor.encode.remote( + prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) + except ActorDiedError as e: + logger.error( + "%s died for second time in a row, marking " + "RayTokenizerGroupPool as unhealthy.", actor) + actor_is_alive = False + if not self._exception: + self._exception = e + self.check_health() + finally: + self._finalize_encode(actor, original_actor, actor_is_alive) + return ret + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self._local_tokenizer_group.get_max_input_len(lora_request) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self._local_tokenizer_group.get_lora_tokenizer(lora_request) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return await self._local_tokenizer_group.get_lora_tokenizer_async( + lora_request) + + def check_health(self): + if self._exception: + raise RuntimeError( + "TokenizerGroupPool is unhealthy.") from self._exception + + +def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: + """Copy over all current process environment variables to the runtime_env. + + The variables in runtime_env will take precedence over the current process + environment variables. 
+ + runtime_env will be modified in place.""" + env_vars = os.environ.copy() + runtime_env.setdefault("env_vars", {}) + env_vars.update(runtime_env["env_vars"]) + runtime_env["env_vars"] = env_vars diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py new file mode 100644 index 0000000..b6e9005 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from vllm.config import TokenizerPoolConfig +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, + get_lora_tokenizer, + get_lora_tokenizer_async, + get_tokenizer) +from vllm.utils import LRUCache + +from .base_tokenizer_group import BaseTokenizerGroup + + +class TokenizerGroup(BaseTokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + max_loras = tokenizer_config.get("max_loras", 0) + self.lora_tokenizers = LRUCache[int, AnyTokenizer]( + capacity=max(max_loras, max_num_seqs) if enable_lora else 0) + + @classmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "TokenizerGroup": + return cls(**init_kwargs) + + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + return True + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self.max_input_length + + def _raise_if_input_too_long(self, + encoded_tokens: List[int], + lora_request: Optional[LoRARequest] = None): + input_length = len(encoded_tokens) + if lora_request: + max_input_length = (lora_request.long_lora_max_len + or self.max_input_length) + else: + max_input_length = self.max_input_length + if max_input_length is not None and input_length > max_input_length: + raise ValueError("Input too long.", input_length, max_input_length) + + def encode(self, + prompt: str, + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + ret = encode_tokens(tokenizer, + prompt, + add_special_tokens=add_special_tokens) + self._raise_if_input_too_long(ret, lora_request) + return ret + + async def encode_async( + self, + prompt: str, + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + ret = encode_tokens(tokenizer, + prompt, + add_special_tokens=add_special_tokens) + self._raise_if_input_too_long(ret, lora_request) + return ret + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return 
self.lora_tokenizers[lora_request.lora_int_id] + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers[lora_request.lora_int_id] diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py new file mode 100644 index 0000000..c12388d --- /dev/null +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, + truncate_tool_call_ids) + +__all__ = [ + "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids" +] diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py new file mode 100644 index 0000000..d893431 --- /dev/null +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -0,0 +1,476 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast + +import huggingface_hub +from huggingface_hub import HfApi, hf_hub_download + +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_base import TokenizerBase +from vllm.utils import is_list_of + +if TYPE_CHECKING: + # make sure `mistral_common` is lazy imported, + # so that users who only use non-mistral models + # will not be bothered by the dependency. + from mistral_common.protocol.instruct.request import ChatCompletionRequest + from mistral_common.tokens.tokenizers.mistral import ( + MistralTokenizer as PublicMistralTokenizer) + + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + +logger = init_logger(__name__) + + +@dataclass +class Encoding: + input_ids: Union[List[int], List[List[int]]] + + +def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): + # SEE: https://github.com/vllm-project/vllm/pull/9951 + # Credits go to: @gcalmettes + # NOTE: There is currently a bug in pydantic where attributes + # declared as iterables are replaced in in the instances by + # pydantic-core ValidatorIterator instance. In particular, this + # affects tool_calls defined in ChatCompletionAssistantMessageParam + # model: + # see: + # - https://github.com/pydantic/pydantic/issues/9467 + # As a result, tool_calls from assistant messages are never + # deserialized in the request object if the tool_calls iterator is + # not consumed. This affect messages passed to the MistralTokenizer + # since no chat template is applied and therefore the tools_calls + # iterator is not directly consumed. + # Issue is tracked on Pydantic side, with resolution planned for + # v2.11 release. 
In the meantime, the official workaround is to + # consume the iterator so the tool_calls are correctly deserialized + # in the OpenAI ChatCompletionAssistantMessageParam object + # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501 + # Official Pydantic Issues: + # - https://github.com/pydantic/pydantic/issues/9541 + # TODO: remove when pydantic v2.11 is released + for i, message in enumerate(request.messages): + if message.get("role") == 'assistant': + tool_calls_validator = message.get("tool_calls", ().__iter__()) + validated_tool_calls = [] + while True: + try: + tool_call = next(tool_calls_validator) # type: ignore + validated_tool_calls.append(tool_call) + except StopIteration: + break + + request.messages[i]["tool_calls"] = validated_tool_calls + + +def truncate_tool_call_ids(request: "ChatCompletionRequest"): + """Truncates tool call IDs for Mistral's ID requirements.""" + for i, message in enumerate(request.messages): + if message.get("role") == 'assistant': + tool_calls = message.get("tool_calls", []) + for tool_call in tool_calls: + if len(tool_call["id"]) > 9: + logger.warning( + "Truncating tool call ID: %s to %s", + tool_call["id"], + tool_call["id"][-9:], + ) + tool_call["id"] = tool_call["id"][-9:] + + request.messages[i]["tool_calls"] = tool_calls + + elif message.get("role") in {"tool_results", "tool"}: + if "tool_call_id" in message: + tool_call_id = message["tool_call_id"] + + if len(tool_call_id) > 9: + logger.warning( + "Truncating tool_call_id: %s to %s", + tool_call_id, + tool_call_id[-9:], + ) + tool_call_id = tool_call_id[-9:] + request.messages[i]["tool_call_id"] = tool_call_id + + +def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: + repo_cache = os.path.join( + huggingface_hub.constants.HF_HUB_CACHE, + huggingface_hub.constants.REPO_ID_SEPARATOR.join( + ["models", *repo_id.split("/")])) + + if revision is None: + revision_file = os.path.join(repo_cache, "refs", "main") + if os.path.isfile(revision_file): + with open(revision_file) as file: + revision = file.read() + + if revision: + revision_dir = os.path.join(repo_cache, "snapshots", revision) + if os.path.isdir(revision_dir): + return os.listdir(revision_dir) + + return [] + + +def find_tokenizer_file(files: List[str]): + file_pattern = re.compile( + r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") + + matched_files = [file for file in files if file_pattern.match(file)] + if len(matched_files) > 1: + raise OSError( + f"Found {len(matched_files)} files matching the " + f"pattern: `{file_pattern.pattern}`. Make sure only one Mistral " + f"tokenizer is present in {files}.") + elif len(matched_files) == 0: + raise OSError( + f"Found {len(matched_files)} files matching the " + f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " + f"tokenizer is present in {files}.") + + return matched_files[0] + + +def make_mistral_chat_completion_request( + messages: List["ChatCompletionMessageParam"], + tools: Optional[List[Dict[str, + Any]]] = None) -> "ChatCompletionRequest": + last_message = cast(Dict[str, Any], messages[-1]) + if last_message["role"] == "assistant": + last_message["prefix"] = True + + # mistral-common requires AssistantMessage content to be string [1]. 
+ # + # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 + for message in messages: + if message.get("role") == "assistant": + content = message.get("content") + if isinstance(content, list): + content = "\n".join(chunk.get("text") for chunk in content) + message["content"] = content + + # The Mistral client, in comparison to the OpenAI client, requires the + # "parameters" dict to be present, even if it's empty. + if tools: + for function in [ + tool["function"] for tool in tools + if tool["type"] == "function" + ]: + if function.get("parameters") is None: + function["parameters"] = {} + + from mistral_common.protocol.instruct.request import ChatCompletionRequest + return ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] + + +class MistralTokenizer(TokenizerBase): + + def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: + self.mistral = tokenizer + self.instruct = tokenizer.instruct_tokenizer + + tokenizer_ = tokenizer.instruct_tokenizer.tokenizer + from mistral_common.tokens.tokenizers.tekken import ( + SpecialTokenPolicy, Tekkenizer) + self.is_tekken = isinstance(tokenizer_, Tekkenizer) + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer) + self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) + if self.is_tekken: + # Make sure special tokens will not raise + tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE + elif self.is_spm: + pass + else: + raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") + + self._vocab = tokenizer_.vocab() + # Convert to a Dict[str, int] to match protocol, but this is a lossy + # conversion. There may be multiple token ids that decode to the same + # string due to partial UTF-8 byte sequences being converted to � + self._vocab_dict = { + token: idx + for idx, token in enumerate(self._vocab) + } + self.tokenizer = tokenizer_ + self._max_token_id = self.vocab_size - 1 + + @classmethod + def from_pretrained(cls, + path_or_repo_id: str, + *, + revision: Optional[str] = None) -> "MistralTokenizer": + if not Path(path_or_repo_id).exists(): + assert len(path_or_repo_id.split("/")) == 2, ( + "You have either provided a non-existent path: " + "{path_or_repo_id} or an invalid HF Hub repo id.") + tokenizer_file = cls._download_mistral_tokenizer_from_hf( + path_or_repo_id, revision) + elif Path(path_or_repo_id).is_dir(): + tokenizer_file_name = find_tokenizer_file( + os.listdir(path_or_repo_id)) + tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name) + else: + assert Path( + path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" + + from mistral_common.tokens.tokenizers.mistral import ( + MistralTokenizer as PublicMistralTokenizer) + mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) + return cls(mistral_tokenizer) + + @staticmethod + def _download_mistral_tokenizer_from_hf(tokenizer_name: str, + revision: Optional[str]) -> str: + try: + hf_api = HfApi() + files = hf_api.list_repo_files(repo_id=tokenizer_name, + revision=revision) + except ConnectionError as exc: + files = list_local_repo_files(repo_id=tokenizer_name, + revision=revision) + + if len(files) == 0: + raise exc + + filename = find_tokenizer_file(files) + + tokenizer_file = hf_hub_download(tokenizer_name, + filename=filename, + revision=revision) + return tokenizer_file + + # the following attributes are set to fit vLLM's design and are used + # by the guided structured 
output backends. + @property + def all_special_tokens_extended(self) -> List[str]: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + # tekken defines its own extended special tokens list + if hasattr(self.tokenizer, "SPECIAL_TOKENS"): + special_tokens = self.tokenizer.SPECIAL_TOKENS + else: + special_tokens = list(SpecialTokens) + return [ + s.value if isinstance(s, SpecialTokens) else s + for s in special_tokens + ] + + @property + def all_special_tokens(self) -> List[str]: + return self.all_special_tokens_extended + + @property + def all_special_ids(self) -> List[int]: + return [ + self.all_special_tokens.index(t) for t in self.all_special_tokens + ] + + @property + def bos_token_id(self) -> int: + return self.tokenizer.bos_id + + @property + def eos_token_id(self) -> int: + return self.tokenizer.eos_id + + @property + def sep_token(self) -> str: + raise NotImplementedError() + + @property + def pad_token(self) -> str: + raise NotImplementedError() + + @property + def is_fast(self) -> bool: + return True + + @property + def vocab_size(self) -> int: + return len(self._vocab) + + @property + def max_token_id(self) -> int: + return self._max_token_id + + def __len__(self) -> int: + return self.vocab_size + + def __call__( + self, + text: Union[str, List[str], List[int]], + text_pair: Optional[str] = None, + add_special_tokens: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + ): + input_ids: Union[List[int], List[List[int]]] + # For List[str], original prompt text + if is_list_of(text, str): + input_ids_: List[List[int]] = [] + for p in text: + each_input_ids = self.encode_one(p, truncation, max_length) + input_ids_.append(each_input_ids) + input_ids = input_ids_ + # For List[int], apply chat template output, already tokens. + elif is_list_of(text, int): + input_ids = text + # For str, single prompt text + else: + input_ids = self.encode_one(text, truncation, max_length) + return Encoding(input_ids=input_ids) + + def get_vocab(self) -> Dict[str, int]: + # NB: the dictionary form of the vocabulary collapses token ids that map + # to the same string but have different bytes + return self._vocab_dict + + def get_added_vocab(self) -> Dict[str, int]: + # Mistral tokenizers have no added vocabulary + return {} + + def encode_one( + self, + text: str, + truncation: bool = False, + max_length: Optional[int] = None, + ) -> List[int]: + # Mistral Tokenizers should not add special tokens + input_ids = self.encode(text) + + if truncation: + input_ids = input_ids[:max_length] + return input_ids + + def encode(self, + text: str, + add_special_tokens: Optional[bool] = None) -> List[int]: + # `encode` should only be used for prompt completion + # it should never be used for chat_completion. 
+ # For chat completion use `apply_chat_template` + if add_special_tokens is not None: + return self.tokenizer.encode(text, + bos=add_special_tokens, + eos=add_special_tokens) + else: + return self.tokenizer.encode(text, bos=True, eos=False) + + def apply_chat_template(self, + messages: List["ChatCompletionMessageParam"], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs) -> List[int]: + + request = make_mistral_chat_completion_request(messages, tools) + encoded = self.mistral.encode_chat_completion(request) + + # encode-decode to get clean prompt + return encoded.tokens + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + from mistral_common.tokens.tokenizers.base import SpecialTokens + if self.is_tekken: + tokens = [ + t for t in tokens + if (t is SpecialTokens.tool_calls + or t not in self.tokenizer._all_special_tokens) + ] + + if any(isinstance(t, bytes) for t in tokens): + # we need to encode and decode all tokens again + shift = self.tokenizer.num_special_tokens + + def _token_to_id(t: str): + t_bytes = t.encode("utf-8") \ + if not isinstance(t, bytes) else t + try: + return shift + \ + self.tokenizer._tekken_token2id_nospecial[t_bytes] + except KeyError: + logger.warning( + "Failed to convert token %s to id," + " replacing with ", t_bytes) + return self.tokenizer.unk_id + + ids = [_token_to_id(t) for t in tokens] + decoded = self.tokenizer.decode(ids) + else: + decoded = "".join(tokens) + else: + # make sure certain special tokens like Tool calls are + # not decoded + special_tokens = {SpecialTokens.tool_calls} + regular_tokens: List[str] = [] + decoded_list = [] + + for token in tokens: + if token in special_tokens: + if regular_tokens: + decoded_list.append( + self.tokenizer.decode(regular_tokens)) + regular_tokens = [] + decoded_list.append(token) + else: + regular_tokens.append(token) + + if regular_tokens: + decoded_list.append( + self.tokenizer.decode(regular_tokens)) # type: ignore + + decoded = ''.join(decoded_list) + + return decoded + + # WARN: Outlines logits processors can overwrite this method. + # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer + # for more. + def decode(self, + ids: Union[List[int], int], + skip_special_tokens: bool = True) -> str: + assert ( + skip_special_tokens + ), "skip_special_tokens=False is not supported for Mistral tokenizers." + + if isinstance(ids, int): + ids = [ids] + return self.tokenizer.decode(ids) + + def convert_ids_to_tokens( + self, + ids: List[int], + skip_special_tokens: bool = True, + ) -> List[str]: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + # TODO(Patrick) - potentially allow special tokens to not be skipped + assert ( + skip_special_tokens + ), "skip_special_tokens=False is not supported for Mistral tokenizers." 
+ + assert self.is_tekken or self.is_spm, type(self.tokenizer) + + if self.is_tekken: + # skip special tokens except tool call + ids = [ + i for i in ids if i > self.tokenizer.num_special_tokens or i == + self.tokenizer.get_control_token(SpecialTokens.tool_calls) + ] + + tokens = [self.tokenizer.id_to_piece(id) for id in ids] + + if any("�" in t for t in tokens) and self.is_tekken: + # if a decoded token contains the replacement character, then the + # token has an incomplete UTF-8 character so we must use bytes + # See: https://github.com/vllm-project/vllm/pull/8640 + # https://github.com/vllm-project/vllm/pull/9625 + # if underlying tokenizeir is sentencepiece, we just add "�" + tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] + + return tokens diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py new file mode 100644 index 0000000..564c0f8 --- /dev/null +++ b/vllm/transformers_utils/utils.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import cache +from os import PathLike +from pathlib import Path +from typing import List, Optional, Union + +from vllm.envs import VLLM_MODEL_REDIRECT_PATH +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def is_s3(model_or_path: str) -> bool: + return model_or_path.lower().startswith('s3://') + + +def check_gguf_file(model: Union[str, PathLike]) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + try: + with model.open("rb") as f: + header = f.read(4) + + return header == b"GGUF" + except Exception as e: + logger.debug("Error reading file %s: %s", model, e) + return False + + +def modelscope_list_repo_files( + repo_id: str, + revision: Optional[str] = None, + token: Union[str, bool, None] = None, +) -> List[str]: + """List files in a modelscope repo.""" + from modelscope.hub.api import HubApi + api = HubApi() + api.login(token) + # same as huggingface_hub.list_repo_files + files = [ + file['Path'] for file in api.get_model_files( + model_id=repo_id, revision=revision, recursive=True) + if file['Type'] == 'blob' + ] + return files + + +@cache +def maybe_model_redirect(model: str) -> str: + """ + Use model_redirect to redirect the model name to a local folder. 
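+
+    The redirect file pointed to by VLLM_MODEL_REDIRECT_PATH is read as one
+    tab-separated "<model name>\t<local path>" pair per line, e.g. the
+    (hypothetical) entry:
+
+        some-org/some-model\t/path/to/local/model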
+ + :param model: hf model name + :return: maybe redirect to a local folder + """ + + model_redirect_path = VLLM_MODEL_REDIRECT_PATH + + if not model_redirect_path: + return model + + if not Path(model_redirect_path).exists(): + return model + + with open(model_redirect_path) as f: + for line in f.readlines(): + try: + model_name, redirect_name = line.split("\t") + if model == model_name: + redirect_name = redirect_name.strip() + logger.info("model redirect: [ %s ] -> [ %s ]", model, + redirect_name) + return redirect_name + except Exception: + pass + + return model diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py new file mode 100644 index 0000000..2103138 --- /dev/null +++ b/vllm/triton_utils/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.triton_utils.importing import HAS_TRITON + +__all__ = ["HAS_TRITON"] + +if HAS_TRITON: + from vllm.triton_utils.custom_cache_manager import ( + maybe_set_triton_cache_manager) + + __all__ += ["maybe_set_triton_cache_manager"] + +else: + def maybe_set_triton_cache_manager() -> None: + return None + __all__ += ["maybe_set_triton_cache_manager"] diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py new file mode 100644 index 0000000..8574f1e --- /dev/null +++ b/vllm/triton_utils/importing.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 + +from importlib.util import find_spec + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +HAS_TRITON = False diff --git a/vllm/usage/__init__.py b/vllm/usage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py new file mode 100644 index 0000000..2ee3f91 --- /dev/null +++ b/vllm/usage/usage_lib.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import json +import logging +import os +import platform +import time +from enum import Enum +from pathlib import Path +from threading import Thread +from typing import Any, Optional, Union +from uuid import uuid4 + +import cpuinfo +import psutil +import requests +import torch + +import vllm.envs as envs +from vllm.connections import global_http_connection +from vllm.version import __version__ as VLLM_VERSION + +_config_home = envs.VLLM_CONFIG_ROOT +_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "usage_stats.json") +_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "do_not_track") +_USAGE_STATS_ENABLED = None +_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER + +_GLOBAL_RUNTIME_DATA = dict[str, Union[str, int, bool]]() + +_USAGE_ENV_VARS_TO_COLLECT = [ + "VLLM_USE_MODELSCOPE", + "VLLM_USE_TRITON_FLASH_ATTN", + "VLLM_ATTENTION_BACKEND", + "VLLM_USE_FLASHINFER_SAMPLER", + "VLLM_PP_LAYER_PARTITION", + "VLLM_USE_TRITON_AWQ", + "VLLM_USE_V1", + "VLLM_ENABLE_V1_MULTIPROCESSING", +] + + +def set_runtime_usage_data(key: str, value: Union[str, int, bool]) -> None: + """Set global usage data that will be sent with every usage heartbeat.""" + _GLOBAL_RUNTIME_DATA[key] = value + + +def is_usage_stats_enabled(): + """Determine whether or not we can send usage stats to the server. + The logic is as follows: + - By default, it should be enabled. 
+ - Three environment variables can disable it: + - VLLM_DO_NOT_TRACK=1 + - DO_NOT_TRACK=1 + - VLLM_NO_USAGE_STATS=1 + - A file in the home directory can disable it if it exists: + - $HOME/.config/vllm/do_not_track + """ + global _USAGE_STATS_ENABLED + if _USAGE_STATS_ENABLED is None: + do_not_track = envs.VLLM_DO_NOT_TRACK + no_usage_stats = envs.VLLM_NO_USAGE_STATS + do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) + + _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats + or do_not_track_file) + return _USAGE_STATS_ENABLED + + +def _get_current_timestamp_ns() -> int: + return int(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1e9) + + +def _detect_cloud_provider() -> str: + # Try detecting through vendor file + vendor_files = [ + "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor", + "/sys/class/dmi/id/product_name", + "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor" + ] + # Mapping of identifiable strings to cloud providers + cloud_identifiers = { + "amazon": "AWS", + "microsoft corporation": "AZURE", + "google": "GCP", + "oraclecloud": "OCI", + } + + for vendor_file in vendor_files: + path = Path(vendor_file) + if path.is_file(): + file_content = path.read_text().lower() + for identifier, provider in cloud_identifiers.items(): + if identifier in file_content: + return provider + + # Try detecting through environment variables + env_to_cloud_provider = { + "RUNPOD_DC_ID": "RUNPOD", + } + for env_var, provider in env_to_cloud_provider.items(): + if os.environ.get(env_var): + return provider + + return "UNKNOWN" + + +class UsageContext(str, Enum): + UNKNOWN_CONTEXT = "UNKNOWN_CONTEXT" + LLM_CLASS = "LLM_CLASS" + API_SERVER = "API_SERVER" + OPENAI_API_SERVER = "OPENAI_API_SERVER" + OPENAI_BATCH_RUNNER = "OPENAI_BATCH_RUNNER" + ENGINE_CONTEXT = "ENGINE_CONTEXT" + + +class UsageMessage: + """Collect platform information and send it to the usage stats server.""" + + def __init__(self) -> None: + # NOTE: vLLM's server _only_ support flat KV pair. + # Do not use nested fields. 
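+        # A fresh UUID identifies this engine instance in the initial report
+        # and in every subsequent heartbeat.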
+ + self.uuid = str(uuid4()) + + # Environment Information + self.provider: Optional[str] = None + self.num_cpu: Optional[int] = None + self.cpu_type: Optional[str] = None + self.cpu_family_model_stepping: Optional[str] = None + self.total_memory: Optional[int] = None + self.architecture: Optional[str] = None + self.platform: Optional[str] = None + self.cuda_runtime: Optional[str] = None + self.gpu_count: Optional[int] = None + self.gpu_type: Optional[str] = None + self.gpu_memory_per_device: Optional[int] = None + self.env_var_json: Optional[str] = None + + # vLLM Information + self.model_architecture: Optional[str] = None + self.vllm_version: Optional[str] = None + self.context: Optional[str] = None + + # Metadata + self.log_time: Optional[int] = None + self.source: Optional[str] = None + + def report_usage(self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: Optional[dict[str, Any]] = None) -> None: + t = Thread(target=self._report_usage_worker, + args=(model_architecture, usage_context, extra_kvs or {}), + daemon=True) + t.start() + + def _report_usage_worker(self, model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any]) -> None: + self._report_usage_once(model_architecture, usage_context, extra_kvs) + self._report_continous_usage() + + def _report_usage_once(self, model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any]) -> None: + # Platform information + from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): + device_property = torch.cuda.get_device_properties(0) + self.gpu_count = torch.cuda.device_count() + self.gpu_type = device_property.name + self.gpu_memory_per_device = device_property.total_memory + if current_platform.is_cuda(): + self.cuda_runtime = torch.version.cuda + self.provider = _detect_cloud_provider() + self.architecture = platform.machine() + self.platform = platform.platform() + self.total_memory = psutil.virtual_memory().total + + info = cpuinfo.get_cpu_info() + self.num_cpu = info.get("count", None) + self.cpu_type = info.get("brand_raw", "") + self.cpu_family_model_stepping = ",".join([ + str(info.get("family", "")), + str(info.get("model", "")), + str(info.get("stepping", "")) + ]) + + # vLLM information + self.context = usage_context.value + self.vllm_version = VLLM_VERSION + self.model_architecture = model_architecture + + # Environment variables + self.env_var_json = json.dumps({ + env_var: getattr(envs, env_var) + for env_var in _USAGE_ENV_VARS_TO_COLLECT + }) + + # Metadata + self.log_time = _get_current_timestamp_ns() + self.source = envs.VLLM_USAGE_SOURCE + + data = vars(self) + if extra_kvs: + data.update(extra_kvs) + + self._write_to_file(data) + self._send_to_server(data) + + def _report_continous_usage(self): + """Report usage every 10 minutes. + + This helps us to collect more data points for uptime of vLLM usages. + This function can also help send over performance metrics over time. 
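+        Each heartbeat also includes any key/value pairs registered via
+        set_runtime_usage_data().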
+ """ + while True: + time.sleep(600) + data = { + "uuid": self.uuid, + "log_time": _get_current_timestamp_ns(), + } + data.update(_GLOBAL_RUNTIME_DATA) + + self._write_to_file(data) + self._send_to_server(data) + + def _send_to_server(self, data: dict[str, Any]) -> None: + try: + global_http_client = global_http_connection.get_sync_client() + global_http_client.post(_USAGE_STATS_SERVER, json=data) + except requests.exceptions.RequestException: + # silently ignore unless we are using debug log + logging.debug("Failed to send usage data to server") + + def _write_to_file(self, data: dict[str, Any]) -> None: + os.makedirs(os.path.dirname(_USAGE_STATS_JSON_PATH), exist_ok=True) + Path(_USAGE_STATS_JSON_PATH).touch(exist_ok=True) + with open(_USAGE_STATS_JSON_PATH, "a") as f: + json.dump(data, f) + f.write("\n") + + +usage_message = UsageMessage() diff --git a/vllm/utils.py b/vllm/utils.py new file mode 100644 index 0000000..88bc08a --- /dev/null +++ b/vllm/utils.py @@ -0,0 +1,2604 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import asyncio +import concurrent +import contextlib +import datetime +import enum +import gc +import getpass +import hashlib +import importlib +import importlib.metadata +import importlib.util +import inspect +import ipaddress +import multiprocessing +import os +import pickle +import re +import signal +import socket +import subprocess +import sys +import tempfile +import threading +import time +import traceback +import types +import uuid +import warnings +import weakref +from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task +from collections import UserDict, defaultdict +from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, + Iterable, Iterator, KeysView, Mapping) +from dataclasses import dataclass, field +from functools import cache, lru_cache, partial, wraps +from types import MappingProxyType +from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, + Optional, Type, TypeVar, Union, cast, overload) +from uuid import uuid4 + +import cachetools +import cloudpickle +import numpy as np +import numpy.typing as npt +import psutil +import torch +import torch.types +import yaml +import zmq +import zmq.asyncio +from packaging.version import Version +from torch.library import Library +from typing_extensions import Never, ParamSpec, TypeIs, assert_never + +import vllm.envs as envs +from vllm.logger import enable_trace_function_call, init_logger +import ixformer.inference.functions as ops + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + +logger = init_logger(__name__) + +# Exception strings for non-implemented encoder/decoder scenarios + +# Reminder: Please update docs/source/features/compatibility_matrix.md +# If the feature combo become valid + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ + "Chunked prefill for encoder/decoder models " + \ + "is not currently supported." 
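The usage reporting added above in `vllm/usage/usage_lib.py` is opt-out. A minimal sketch of the two opt-out paths that `is_usage_stats_enabled()` honours follows; the `~/.config/vllm` location assumes the default `VLLM_CONFIG_ROOT`:

```python
import os
from pathlib import Path

# Option 1: set any of these before vLLM is imported; they are read through
# vllm.envs inside is_usage_stats_enabled().
os.environ["VLLM_NO_USAGE_STATS"] = "1"  # VLLM_DO_NOT_TRACK=1 / DO_NOT_TRACK=1 also work

# Option 2: create the opt-out marker file under the vLLM config root
# (shown here at its assumed default location).
opt_out = Path.home() / ".config" / "vllm" / "do_not_track"
opt_out.parent.mkdir(parents=True, exist_ok=True)
opt_out.touch()
```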
+ +STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( + "Models with logits_soft_cap " + "require FlashInfer backend, which is " + "currently not supported for encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not " + "currently supported with " + "encoder/decoder models.") + +STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only " + "backends currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " + "currently supported with encoder/" + "decoder models.") + +# Efficiently import all enc/dec error strings +# rather than having to import all of the above +STR_NOT_IMPL_ENC_DEC_ERR_STRS = { + "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA, + "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL": + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP, + "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA, + "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, + "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, + "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, + "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, + "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, +} + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + +GB_bytes = 1_000_000_000 +"""The number of bytes in one gigabyte (GB).""" + +GiB_bytes = 1 << 30 +"""The number of bytes in one gibibyte (GiB).""" + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, + "int8": torch.int8, +} + +TORCH_DTYPE_TO_NUMPY_DTYPE = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int32: np.int32, + torch.int64: np.int64, +} + +P = ParamSpec('P') +T = TypeVar("T") +U = TypeVar("U") + +_K = TypeVar("_K", bound=Hashable) +_V = TypeVar("_V") +_T = TypeVar("_T") + + +class _Sentinel: + ... 
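The size and dtype constants just defined (`GB_bytes`, `GiB_bytes`, `STR_DTYPE_TO_TORCH_DTYPE`) are used throughout this file for cache sizing. A small illustrative sketch, assuming the patched build is importable as `vllm.utils`:

```python
import torch

from vllm.utils import GB_bytes, GiB_bytes, STR_DTYPE_TO_TORCH_DTYPE

# Express a 32 GiB KV-cache budget in decimal gigabytes.
print(32 * GiB_bytes / GB_bytes)  # ~34.36

# Resolve a config-style dtype string to a torch dtype and its element size.
dtype = STR_DTYPE_TO_TORCH_DTYPE["bfloat16"]
print(dtype, torch.tensor([], dtype=dtype).element_size())  # torch.bfloat16 2
```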
+ + +ALL_PINNED_SENTINEL = _Sentinel() + + +class Device(enum.Enum): + GPU = enum.auto() + CPU = enum.auto() + + +class LayerBlockType(enum.Enum): + attention = "attention" + mamba = "mamba" + + +class Counter: + + def __init__(self, start: int = 0) -> None: + self.counter = start + + def __next__(self) -> int: + i = self.counter + self.counter += 1 + return i + + def reset(self) -> None: + self.counter = 0 + + +class _MappingOrderCacheView(UserDict[_K, _V]): + + def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]): + super().__init__(data) + self.ordered_keys = ordered_keys + + def __iter__(self) -> Iterator[_K]: + return iter(self.ordered_keys) + + def keys(self) -> KeysView[_K]: + return KeysView(self.ordered_keys) + + +class CacheInfo(NamedTuple): + hits: int + total: int + + @property + def hit_ratio(self) -> float: + if self.total == 0: + return 0 + + return self.hits / self.total + + +class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): + + def __init__(self, + capacity: float, + getsizeof: Optional[Callable[[_V], float]] = None): + super().__init__(capacity, getsizeof) + self.pinned_items = set[_K]() + self.capacity = capacity + + self._hits = 0 + self._total = 0 + + def __delitem__(self, key: _K) -> None: + run_on_remove = key in self + value = self.__getitem__(key) + super().__delitem__(key) + if key in self.pinned_items: + # Todo: add warning to inform that del pinned item + self._unpin(key) + if run_on_remove: + self._on_remove(key, value) + + @property + def cache(self) -> Mapping[_K, _V]: + """Return the internal cache dictionary in order (read-only).""" + return _MappingOrderCacheView( + self._Cache__data, # type: ignore + self.order) + + @property + def order(self) -> Mapping[_K, None]: + """Return the internal order dictionary (read-only).""" + return MappingProxyType(self._LRUCache__order) # type: ignore + + def stat(self) -> CacheInfo: + return CacheInfo(hits=self._hits, total=self._total) + + def touch(self, key: _K) -> None: + try: + self._LRUCache__order.move_to_end(key) # type: ignore + except KeyError: + self._LRUCache__order[key] = None + + @overload + def get(self, key: _K, /) -> Optional[_V]: + ... + + @overload + def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: + ... + + def get(self, + key: _K, + /, + default: Optional[Union[_V, + _T]] = None) -> Optional[Union[_V, _T]]: + value: Optional[Union[_V, _T]] + if key in self: + value = self.__getitem__(key) + + self._hits += 1 + else: + value = default + + self._total += 1 + return value + + @overload + def pop(self, key: _K) -> _V: + ... + + @overload + def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: + ... + + def pop(self, + key: _K, + default: Optional[Union[_V, + _T]] = None) -> Optional[Union[_V, _T]]: + value: Optional[Union[_V, _T]] + if key not in self: + return default + + value = self[key] + del self[key] + return value + + def put(self, key: _K, value: _V) -> None: + self.__setitem__(key, value) + + def pin(self, key: _K) -> None: + """ + Pins a key in the cache preventing it from being + evicted in the LRU order. + """ + if key not in self: + raise ValueError(f"Cannot pin key: {key} not in cache.") + self.pinned_items.add(key) + + def _unpin(self, key: _K) -> None: + """ + Unpins a key in the cache allowing it to be + evicted in the LRU order. 
+ """ + self.pinned_items.remove(key) + + def _on_remove(self, key: _K, value: Optional[_V]) -> None: + pass + + def remove_oldest(self, *, remove_pinned: bool = False) -> None: + if len(self) == 0: + return + + self.popitem(remove_pinned=remove_pinned) + + def _remove_old_if_needed(self) -> None: + while self.currsize > self.capacity: + self.remove_oldest() + + def clear(self) -> None: + while len(self) > 0: + self.remove_oldest(remove_pinned=True) + + def popitem(self, remove_pinned: bool = False): + """Remove and return the `(key, value)` pair least recently used.""" + if not remove_pinned: + # pop the oldest item in the cache that is not pinned + lru_key = next( + (key for key in self.order if key not in self.pinned_items), + ALL_PINNED_SENTINEL) + if lru_key is ALL_PINNED_SENTINEL: + raise RuntimeError("All items are pinned, " + "cannot remove oldest from the cache.") + else: + lru_key = next(iter(self.order)) + value = self.pop(cast(_K, lru_key)) + return (lru_key, value) + + +class PyObjectCache: + """Used to cache python objects to avoid object allocations + across scheduler iterations. + """ + + def __init__(self, obj_builder): + self._obj_builder = obj_builder + self._index = 0 + + self._obj_cache = [] + for _ in range(128): + self._obj_cache.append(self._obj_builder()) + + def _grow_cache(self): + # Double the size of the cache + num_objs = len(self._obj_cache) + for _ in range(num_objs): + self._obj_cache.append(self._obj_builder()) + + def get_object(self): + """Returns a pre-allocated cached object. If there is not enough + objects, then the cache size will double. + """ + if self._index >= len(self._obj_cache): + self._grow_cache() + assert self._index < len(self._obj_cache) + + obj = self._obj_cache[self._index] + self._index += 1 + + return obj + + def reset(self): + """Makes all cached-objects available for the next scheduler iteration. + """ + self._index = 0 + + +@cache +def get_max_shared_memory_bytes(gpu: int = 0) -> int: + """Returns the maximum shared memory per thread block in bytes.""" + from vllm import _custom_ops as ops + max_shared_mem = ( + ops.get_max_shared_memory_per_block_device_attribute(gpu)) + # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py + # will fail + assert max_shared_mem > 0, "max_shared_mem can not be zero" + return int(max_shared_mem) + + +def get_cpu_memory() -> int: + """Returns the total CPU memory of the node in bytes.""" + return psutil.virtual_memory().total + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def make_async( + func: Callable[P, T], + executor: Optional[concurrent.futures.Executor] = None +) -> Callable[P, Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=executor, func=p_func) + + return _async_wrapper + + +def _next_task(iterator: AsyncGenerator[T, None], + loop: AbstractEventLoop) -> Task: + # Can use anext() in python >= 3.10 + return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] + + +async def merge_async_iterators( + *iterators: AsyncGenerator[T, + None], ) -> AsyncGenerator[tuple[int, T], None]: + """Merge multiple asynchronous iterators into a single iterator. 
+ + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + if len(iterators) == 1: + # Fast-path single iterator case. + async for item in iterators[0]: + yield 0, item + return + + loop = asyncio.get_running_loop() + + awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} + try: + while awaits: + done, _ = await asyncio.wait(awaits.keys(), + return_when=FIRST_COMPLETED) + for d in done: + pair = awaits.pop(d) + try: + item = await d + i, it = pair + awaits[_next_task(it, loop)] = pair + yield i, item + except StopAsyncIteration: + pass + finally: + # Cancel any remaining iterators + for f, (_, it) in awaits.items(): + with contextlib.suppress(BaseException): + f.cancel() + await it.aclose() + + +async def collect_from_async_generator( + iterator: AsyncGenerator[T, None]) -> list[T]: + """Collect all items from an async generator into a list.""" + items = [] + async for item in iterator: + items.append(item) + return items + + +def get_ip() -> str: + host_ip = envs.VLLM_HOST_IP + if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ: + logger.warning( + "The environment variable HOST_IP is deprecated and ignored, as" + " it is often used by Docker and other software to" + " interact with the container's network stack. Please " + "use VLLM_HOST_IP instead to set the IP address for vLLM processes" + " to communicate with each other.") + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." + "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", + stacklevel=2) + return "0.0.0.0" + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def get_distributed_init_method(ip: str, port: int) -> str: + # Brackets are not permitted in ipv4 addresses, + # see https://github.com/python/cpython/issues/103848 + return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" + + +def get_open_zmq_ipc_path() -> str: + base_rpc_path = envs.VLLM_RPC_BASE_PATH + return f"ipc://{base_rpc_path}/{uuid4()}" + + +def get_open_zmq_inproc_path() -> str: + return f"inproc://{uuid4()}" + + +def get_open_port() -> int: + """ + Get an open port for the vLLM process to listen on. + An edge case to handle, is when we run data parallel, + we need to avoid ports that are potentially used by + the data parallel master process. + Right now we reserve 10 ports for the data parallel master + process. Currently it uses 2 ports. 
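+    If the environment variable VLLM_PORT is set, that port is tried first
+    and incremented until a free port is found.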
+ """ + if "VLLM_DP_MASTER_PORT" in os.environ: + dp_port = envs.VLLM_DP_MASTER_PORT + while True: + port = _get_open_port() + if dp_port <= port < dp_port + 10: + continue + return port + return _get_open_port() + + +def _get_open_port() -> int: + port = envs.VLLM_PORT + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", + port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def find_process_using_port(port: int) -> Optional[psutil.Process]: + # TODO: We can not check for running processes with network + # port on macOS. Therefore, we can not have a full graceful shutdown + # of vLLM. For now, let's not look for processes in this case. + # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/ + if sys.platform.startswith("darwin"): + return None + + for conn in psutil.net_connections(): + if conn.laddr.port == port: + try: + return psutil.Process(conn.pid) + except psutil.NoSuchProcess: + return None + return None + + +def update_environment_variables(envs: dict[str, str]): + for k, v in envs.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s " + "from '%s' to '%s'", k, os.environ[k], v) + os.environ[k] = v + + +def chunk_list(lst: list[T], chunk_size: int): + """Yield successive chunk_size chunks from lst.""" + for i in range(0, len(lst), chunk_size): + yield lst[i:i + chunk_size] + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + +def round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + +def round_down(x: int, y: int) -> int: + return (x // y) * y + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. 
+ # | E4M3 | E5M2 + #-----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm import _custom_ops as ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp + + +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: Optional[int] = None, + device: Optional[str] = "cuda", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + from vllm.platforms import current_platform + current_platform.seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + + key_caches: list[torch.Tensor] = [] + value_caches: list[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: Optional[int] = None, + device: Optional[str] = "cuda", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + from vllm.platforms import current_platform + current_platform.seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_cache, -scale, scale) + else: + raise 
ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + + +@cache +def is_pin_memory_available() -> bool: + from vllm.platforms import current_platform + return current_platform.is_pin_memory_available() + + +@cache +def is_uva_available() -> bool: + """Check if Unified Virtual Addressing (UVA) is available.""" + # UVA requires pinned memory. + # TODO: Add more requirements for UVA if needed. + return is_pin_memory_available() + + +class DeviceMemoryProfiler: + + def __init__(self, device: Optional[torch.types.Device] = None): + self.device = device + + def current_memory_usage(self) -> float: + # Return the memory usage in bytes. + from vllm.platforms import current_platform + return current_platform.get_current_memory_usage(self.device) + + def __enter__(self): + self.initial_memory = self.current_memory_usage() + # This allows us to call methods of the context manager if needed + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.final_memory = self.current_memory_usage() + self.consumed_memory = self.final_memory - self.initial_memory + + # Force garbage collection + gc.collect() + + +def make_ndarray_with_pad( + x: list[list[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len: Optional[int] = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + if max_len is None: + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + + padded_x = np.full((len(x), max_len), pad, dtype=dtype) + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb + + return padded_x + + +def make_tensor_with_pad( + x: list[list[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. 
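+
+    Example:
+        >>> make_tensor_with_pad([[1, 2, 3], [4]], pad=0, dtype=torch.int64)
+        tensor([[1, 2, 3],
+                [4, 0, 0]])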
+ """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor + + +def async_tensor_h2d( + data: list, + dtype: torch.dtype, + target_device: Union[str, torch.device], + pin_memory: bool, +) -> torch.Tensor: + """Asynchronously create a tensor and copy it from host to device.""" + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) + + +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + +# `collections` helpers +def is_list_of( + value: object, + typ: Union[type[T], tuple[type[T], ...]], + *, + check: Literal["first", "all"] = "first", +) -> TypeIs[list[T]]: + if not isinstance(value, list): + return False + + if check == "first": + return len(value) == 0 or isinstance(value[0], typ) + elif check == "all": + return all(isinstance(v, typ) for v in value) + + assert_never(check) + + +def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]: + """Flatten a list of lists to a single list.""" + return [item for sublist in lists for item in sublist] + + +def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): + """ + Unlike :class:`itertools.groupby`, groups are not broken by + non-contiguous data. + """ + groups = defaultdict[_K, list[_V]](list) + + for value in values: + groups[key(value)].append(value) + + return groups.items() + + +# TODO: This function can be removed if transformer_modules classes are +# serialized by value when communicating between processes +def init_cached_hf_modules() -> None: + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() + + +@cache +def find_library(lib_name: str) -> str: + """ + Find the library file in the system. + `lib_name` is full filename, with both prefix and suffix. + This function resolves `lib_name` to the full path of the library. + """ + # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa + # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard + # `/sbin/ldconfig` should exist in all Linux systems. + # `/sbin/ldconfig` searches the library in the system + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] + # `LD_LIBRARY_PATH` searches the library in the user-defined paths + env_ld_library_path = envs.LD_LIBRARY_PATH + if not locs and env_ld_library_path: + locs = [ + os.path.join(dir, lib_name) + for dir in env_ld_library_path.split(":") + if os.path.exists(os.path.join(dir, lib_name)) + ] + if not locs: + raise ValueError(f"Cannot find {lib_name} in the system.") + return locs[0] + + +def find_nccl_library() -> str: + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. 
+ """ + so_file = envs.VLLM_NCCL_SO_PATH + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", + so_file) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + +prev_set_stream = torch.cuda.set_stream + +_current_stream = None + + +def _patched_set_stream(stream: torch.cuda.Stream) -> None: + global _current_stream + _current_stream = stream + prev_set_stream(stream) + + +torch.cuda.set_stream = _patched_set_stream + + +def current_stream() -> torch.cuda.Stream: + """ + replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. + it turns out that `torch.cuda.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.cuda.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.cuda.current_stream()`. + + the underlying hypothesis is that we do not call `torch._C._cuda_setStream` + from C/C++ code. + """ + from vllm.platforms import current_platform + global _current_stream + if _current_stream is None: + # when this function is called before any stream is set, + # we return the default stream. + # On ROCm using the default 0 stream in combination with RCCL + # is hurting performance. Therefore creating a dedicated stream + # per process + _current_stream = torch.cuda.Stream() if current_platform.is_rocm( + ) else torch.cuda.current_stream() + return _current_stream + + +def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: + """Set up function tracing for the current thread, + if enabled via the VLLM_TRACE_FUNCTION environment variable + """ + + if envs.VLLM_TRACE_FUNCTION: + tmp_dir = tempfile.gettempdir() + # add username to tmp_dir to avoid permission issues + tmp_dir = os.path.join(tmp_dir, getpass.getuser()) + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", + f"vllm-instance-{vllm_config.instance_id}", + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) + + +# `functools` helpers +def identity(value: T, **kwargs) -> T: + """Returns the first provided value.""" + return value + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_args( + start_index: int, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None, +) -> Callable[[F], F]: + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + params = inspect.signature(fn).parameters + pos_types = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + pos_kws = [ + kw for kw, param in params.items() if param.kind in pos_types + ] + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_args = pos_kws[start_index:len(args)] + if deprecated_args: + msg = ( + f"The positional arguments {deprecated_args} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level 
+ ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None, +) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless( + cuda_visible_devices: Optional[str] = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. + + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + from vllm.platforms import current_platform + if not torch.cuda._is_compiled(): + return 0 + if current_platform.is_rocm(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = torch.cuda._device_count_amdsmi() if (hasattr( + torch.cuda, "_device_count_amdsmi")) else -1 + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. 
+ return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +def cuda_is_initialized() -> bool: + """Check if CUDA is initialized.""" + if not torch.cuda._is_compiled(): + return False + return torch.cuda.is_initialized() + + +def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: + """Make an instance method that weakly references + its associated instance and no-ops once that + instance is collected.""" + ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined] + unbound = bound_method.__func__ # type: ignore[attr-defined] + + def weak_bound(*args, **kwargs) -> None: + if inst := ref(): + unbound(inst, *args, **kwargs) + + return weak_bound + + +#From: https://stackoverflow.com/a/4104188/2749989 +def run_once(f: Callable[P, None]) -> Callable[P, None]: + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) + + wrapper.has_run = False # type: ignore[attr-defined] + return wrapper + + +class StoreBoolean(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + if values.lower() == "true": + setattr(namespace, self.dest, True) + elif values.lower() == "false": + setattr(namespace, self.dest, False) + else: + raise ValueError(f"Invalid boolean value: {values}. " + "Expected 'true' or 'false'.") + + +class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter): + """SortedHelpFormatter that sorts arguments by their option strings.""" + + def add_arguments(self, actions): + actions = sorted(actions, key=lambda x: x.option_strings) + super().add_arguments(actions) + + +class FlexibleArgumentParser(argparse.ArgumentParser): + """ArgumentParser that allows both underscore and dash in names.""" + + def __init__(self, *args, **kwargs): + # Set the default 'formatter_class' to SortedHelpFormatter + if 'formatter_class' not in kwargs: + kwargs['formatter_class'] = SortedHelpFormatter + super().__init__(*args, **kwargs) + + def parse_args(self, args=None, namespace=None): + if args is None: + args = sys.argv[1:] + + # Check for --model in command line arguments first + if args and args[0] == "serve": + model_in_cli_args = any(arg == '--model' for arg in args) + + if model_in_cli_args: + raise ValueError( + "With `vllm serve`, you should provide the model as a " + "positional argument or in a config file instead of via " + "the `--model` option.") + + if '--config' in args: + args = self._pull_args_from_config(args) + + # Convert underscores to dashes and vice versa in argument names + processed_args = [] + for arg in args: + if arg.startswith('--'): + if '=' in arg: + key, value = arg.split('=', 1) + key = '--' + key[len('--'):].replace('_', '-') + processed_args.append(f'{key}={value}') + else: + processed_args.append('--' + + arg[len('--'):].replace('_', '-')) + elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: + # allow -O flag to be used without space, e.g. 
-O3 + processed_args.append('-O') + processed_args.append(arg[2:]) + else: + processed_args.append(arg) + + return super().parse_args(processed_args, namespace) + + def check_port(self, value): + try: + value = int(value) + except ValueError: + msg = "Port must be an integer" + raise argparse.ArgumentTypeError(msg) from None + + if not (1024 <= value <= 65535): + raise argparse.ArgumentTypeError( + "Port must be between 1024 and 65535") + + return value + + def _pull_args_from_config(self, args: list[str]) -> list[str]: + """Method to pull arguments specified in the config file + into the command-line args variable. + + The arguments in config file will be inserted between + the argument list. + + example: + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + ```python + $: vllm {serve,chat,complete} "facebook/opt-12B" \ + --config config.yaml -tp 2 + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--config', 'config.yaml', + '-tp', '2' + ] + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', + '-tp', '2' + ] + ``` + + Please note how the config args are inserted after the sub command. + this way the order of priorities is maintained when these are args + parsed by super(). + """ + assert args.count( + '--config') <= 1, "More than one config file specified!" + + index = args.index('--config') + if index == len(args) - 1: + raise ValueError("No config file specified! \ + Please check your command-line arguments.") + + file_path = args[index + 1] + + config_args = self._load_config_file(file_path) + + # 0th index is for {serve,chat,complete} + # optionally followed by model_tag (only for serve) + # followed by config args + # followed by rest of cli args. + # maintaining this order will enforce the precedence + # of cli > config > defaults + if args[0] == "serve": + model_in_cli = len(args) > 1 and not args[1].startswith('-') + model_in_config = any(arg == '--model' for arg in config_args) + + if not model_in_cli and not model_in_config: + raise ValueError( + "No model specified! Please specify model either " + "as a positional argument or in a config file.") + + if model_in_cli: + # Model specified as positional arg, keep CLI version + args = [args[0]] + [ + args[1] + ] + config_args + args[2:index] + args[index + 2:] + else: + # No model in CLI, use config if available + args = [args[0] + ] + config_args + args[1:index] + args[index + 2:] + else: + args = [args[0]] + config_args + args[1:index] + args[index + 2:] + + return args + + def _load_config_file(self, file_path: str) -> list[str]: + """Loads a yaml file and returns the key value pairs as a + flattened list with argparse like pattern + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + returns: + processed_args: list[str] = [ + '--port': '12323', + '--tensor-parallel-size': '4' + ] + """ + extension: str = file_path.split('.')[-1] + if extension not in ('yaml', 'yml'): + raise ValueError( + "Config file must be of a yaml/yml type.\ + %s supplied", extension) + + # only expecting a flat dictionary of atomic types + processed_args: list[str] = [] + + config: dict[str, Union[int, str]] = {} + try: + with open(file_path) as config_file: + config = yaml.safe_load(config_file) + except Exception as ex: + logger.error( + "Unable to read the config file at %s. 
\ + Make sure path is correct", file_path) + raise ex + + store_boolean_arguments = [ + action.dest for action in self._actions + if isinstance(action, StoreBoolean) + ] + + for key, value in config.items(): + if isinstance(value, bool) and key not in store_boolean_arguments: + if value: + processed_args.append('--' + key) + else: + processed_args.append('--' + key) + processed_args.append(str(value)) + + return processed_args + + +async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, + **kwargs): + """Utility function to run async task in a lock""" + async with lock: + return await task(*args, **kwargs) + + +def supports_kw( + callable: Callable[..., object], + kw_name: str, + *, + requires_kw_only: bool = False, + allow_var_kwargs: bool = True, +) -> bool: + """Check if a keyword is a valid kwarg for a callable; if requires_kw_only + disallows kwargs names that can also be positional arguments. + """ + params = inspect.signature(callable).parameters + if not params: + return False + + param_val = params.get(kw_name) + + # Types where the it may be valid, i.e., explicitly defined & nonvariadic + passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY)) + + if param_val: + is_sig_param = param_val.kind in passable_kw_types + # We want kwargs only, but this is passable as a positional arg + if (requires_kw_only and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY): + return False + if ((requires_kw_only + and param_val.kind == inspect.Parameter.KEYWORD_ONLY) + or (not requires_kw_only and is_sig_param)): + return True + + # If we're okay with var-kwargs, it's supported as long as + # the kw_name isn't something like *args, **kwargs + if allow_var_kwargs: + # Get the last param; type is ignored here because params is a proxy + # mapping, but it wraps an ordered dict, and they appear in order. + # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters + last_param = params[next(reversed(params))] # type: ignore + return (last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name) + return False + + +def resolve_mm_processor_kwargs( + init_kwargs: Optional[Mapping[str, object]], + inference_kwargs: Optional[Mapping[str, object]], + callable: Callable[..., object], + *, + requires_kw_only: bool = True, + allow_var_kwargs: bool = False, +) -> dict[str, Any]: + """Applies filtering to eliminate invalid mm_processor_kwargs, i.e., + those who are not explicit keywords to the given callable (of one is + given; otherwise no filtering is done), then merges the kwarg dicts, + giving priority to inference_kwargs if there are any collisions. + + In the case that no kwarg overrides are provided, returns an empty + dict so that it can still be kwarg expanded into the callable later on. + + If allow_var_kwargs=True, allows for things that can be expanded into + kwargs as long as they aren't naming collision for var_kwargs or potential + positional arguments. 
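+
+    Example:
+        >>> def processor(*, num_crops: int = 4): ...
+        >>> resolve_mm_processor_kwargs({"num_crops": 2}, {"num_crops": 8},
+        ...                             processor)
+        {'num_crops': 8}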
+ """ + # Filter inference time multimodal processor kwargs provided + runtime_mm_kwargs = get_allowed_kwarg_only_overrides( + callable, + overrides=inference_kwargs, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + + # Filter init time multimodal processor kwargs provided + init_mm_kwargs = get_allowed_kwarg_only_overrides( + callable, + overrides=init_kwargs, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + + # Merge the final processor kwargs, prioritizing inference + # time values over the initialization time values. + mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs} + return mm_processor_kwargs + + +def get_allowed_kwarg_only_overrides( + callable: Callable[..., object], + overrides: Optional[Mapping[str, object]], + *, + requires_kw_only: bool = True, + allow_var_kwargs: bool = False, +) -> dict[str, Any]: + """ + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. + + Args: + callable: Callable which takes 0 or more keyword only arguments. + If None is provided, all overrides names are allowed. + overrides: Potential overrides to be used when invoking the callable. + allow_var_kwargs: Allows overrides that are expandable for var kwargs. + + Returns: + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. + """ + if not overrides: + return {} + + # Drop any mm_processor_kwargs provided by the user that + # are not kwargs, unless it can fit it var_kwargs param + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if supports_kw(callable, + kwarg_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs) + } + + # If anything is dropped, log a warning + dropped_keys = overrides.keys() - filtered_overrides.keys() + if dropped_keys: + if requires_kw_only: + logger.warning( + "The following intended overrides are not keyword-only args " + "and will be dropped: %s", dropped_keys) + else: + logger.warning( + "The following intended overrides are not keyword args " + "and will be dropped: %s", dropped_keys) + + return filtered_overrides + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + base_torch_version = Version(Version(torch.__version__).base_version) + return base_torch_version >= Version("2.4.0") + + +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. 
+def supports_custom_op() -> bool: + return False + return hasattr(torch.library, "custom_op") + + +class AtomicCounter: + """An atomic, thread-safe counter""" + + def __init__(self, initial=0): + """Initialize a new atomic counter to given initial value""" + self._value = initial + self._lock = threading.Lock() + + def inc(self, num=1): + """Atomically increment the counter by num and return the new value""" + with self._lock: + self._value += num + return self._value + + def dec(self, num=1): + """Atomically decrement the counter by num and return the new value""" + with self._lock: + self._value -= num + return self._value + + @property + def value(self): + return self._value + + +# Adapted from: https://stackoverflow.com/a/47212782/5082708 +class LazyDict(Mapping[str, T], Generic[T]): + + def __init__(self, factory: dict[str, Callable[[], T]]): + self._factory = factory + self._dict: dict[str, T] = {} + + def __getitem__(self, key: str) -> T: + if key not in self._dict: + if key not in self._factory: + raise KeyError(key) + self._dict[key] = self._factory[key]() + return self._dict[key] + + def __setitem__(self, key: str, value: Callable[[], T]): + self._factory[key] = value + + def __iter__(self): + return iter(self._factory) + + def __len__(self): + return len(self._factory) + + +class ClassRegistry(UserDict[Type[T], _V]): + + def __getitem__(self, key: Type[T]) -> _V: + for cls in key.mro(): + if cls in self.data: + return self.data[cls] + + raise KeyError(key) + + def __contains__(self, key: object) -> bool: + return self.contains(key) + + def contains(self, key: object, *, strict: bool = False) -> bool: + if not isinstance(key, type): + return False + + if strict: + return key in self.data + + return any(cls in self.data for cls in key.mro()) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return ops.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors") + + +def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: + """ + Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). + """ + assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" + return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + + +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(torch, _MockModule) + except ModuleNotFoundError: + return False + + +def import_from_path(module_name: str, file_path: Union[str, os.PathLike]): + """ + Import a Python file according to its file path. 
+ + Based on the official recipe: + https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + """ + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ModuleNotFoundError(f"No module named '{module_name}'") + + assert spec.loader is not None + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +@cache +def get_vllm_optional_dependencies(): + metadata = importlib.metadata.metadata("vllm") + requirements = metadata.get_all("Requires-Dist", []) + extras = metadata.get_all("Provides-Extra", []) + + return { + extra: [ + re.split(r";|>=|<=|==", req)[0] for req in requirements + if req.endswith(f'extra == "{extra}"') + ] + for extra in extras + } + + +class _PlaceholderBase: + """ + Disallows downstream usage of placeholder modules. + + We need to explicitly override each dunder method because + :meth:`__getattr__` is not called when they are accessed. + + See also: + [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) + """ + + def __getattr__(self, key: str) -> Never: + """ + The main class should implement this to throw an error + for attribute accesses representing downstream usage. + """ + raise NotImplementedError + + # [Basic customization] + + def __lt__(self, other: object): + return self.__getattr__("__lt__") + + def __le__(self, other: object): + return self.__getattr__("__le__") + + def __eq__(self, other: object): + return self.__getattr__("__eq__") + + def __ne__(self, other: object): + return self.__getattr__("__ne__") + + def __gt__(self, other: object): + return self.__getattr__("__gt__") + + def __ge__(self, other: object): + return self.__getattr__("__ge__") + + def __hash__(self): + return self.__getattr__("__hash__") + + def __bool__(self): + return self.__getattr__("__bool__") + + # [Callable objects] + + def __call__(self, *args: object, **kwargs: object): + return self.__getattr__("__call__") + + # [Container types] + + def __len__(self): + return self.__getattr__("__len__") + + def __getitem__(self, key: object): + return self.__getattr__("__getitem__") + + def __setitem__(self, key: object, value: object): + return self.__getattr__("__setitem__") + + def __delitem__(self, key: object): + return self.__getattr__("__delitem__") + + # __missing__ is optional according to __getitem__ specification, + # so it is skipped + + # __iter__ and __reversed__ have a default implementation + # based on __len__ and __getitem__, so they are skipped. 
+ + # [Numeric Types] + + def __add__(self, other: object): + return self.__getattr__("__add__") + + def __sub__(self, other: object): + return self.__getattr__("__sub__") + + def __mul__(self, other: object): + return self.__getattr__("__mul__") + + def __matmul__(self, other: object): + return self.__getattr__("__matmul__") + + def __truediv__(self, other: object): + return self.__getattr__("__truediv__") + + def __floordiv__(self, other: object): + return self.__getattr__("__floordiv__") + + def __mod__(self, other: object): + return self.__getattr__("__mod__") + + def __divmod__(self, other: object): + return self.__getattr__("__divmod__") + + def __pow__(self, other: object, modulo: object = ...): + return self.__getattr__("__pow__") + + def __lshift__(self, other: object): + return self.__getattr__("__lshift__") + + def __rshift__(self, other: object): + return self.__getattr__("__rshift__") + + def __and__(self, other: object): + return self.__getattr__("__and__") + + def __xor__(self, other: object): + return self.__getattr__("__xor__") + + def __or__(self, other: object): + return self.__getattr__("__or__") + + # r* and i* methods have lower priority than + # the methods for left operand so they are skipped + + def __neg__(self): + return self.__getattr__("__neg__") + + def __pos__(self): + return self.__getattr__("__pos__") + + def __abs__(self): + return self.__getattr__("__abs__") + + def __invert__(self): + return self.__getattr__("__invert__") + + # __complex__, __int__ and __float__ have a default implementation + # based on __index__, so they are skipped. + + def __index__(self): + return self.__getattr__("__index__") + + def __round__(self, ndigits: object = ...): + return self.__getattr__("__round__") + + def __trunc__(self): + return self.__getattr__("__trunc__") + + def __floor__(self): + return self.__getattr__("__floor__") + + def __ceil__(self): + return self.__getattr__("__ceil__") + + # [Context managers] + + def __enter__(self): + return self.__getattr__("__enter__") + + def __exit__(self, *args: object, **kwargs: object): + return self.__getattr__("__exit__") + + +class PlaceholderModule(_PlaceholderBase): + """ + A placeholder object to use when a module does not exist. + + This enables more informative errors when trying to access attributes + of a module that does not exists. 
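+
+    Example:
+        audio = PlaceholderModule("librosa")
+        audio.load(...)  # ImportError is raised here, not at construction,
+                         # if librosa is missing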
+    """
+
+    def __init__(self, name: str) -> None:
+        super().__init__()
+
+        # Apply name mangling to avoid conflicting with module attributes
+        self.__name = name
+
+    def placeholder_attr(self, attr_path: str):
+        return _PlaceholderModuleAttr(self, attr_path)
+
+    def __getattr__(self, key: str):
+        name = self.__name
+
+        try:
+            importlib.import_module(name)
+        except ImportError as exc:
+            for extra, names in get_vllm_optional_dependencies().items():
+                if name in names:
+                    msg = f"Please install vllm[{extra}] for {extra} support"
+                    raise ImportError(msg) from exc
+
+            raise exc
+
+        raise AssertionError("PlaceholderModule should not be used "
+                             "when the original module can be imported")
+
+
+class _PlaceholderModuleAttr(_PlaceholderBase):
+
+    def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
+        super().__init__()
+
+        # Apply name mangling to avoid conflicting with module attributes
+        self.__module = module
+        self.__attr_path = attr_path
+
+    def placeholder_attr(self, attr_path: str):
+        return _PlaceholderModuleAttr(self.__module,
+                                      f"{self.__attr_path}.{attr_path}")
+
+    def __getattr__(self, key: str):
+        getattr(self.__module, f"{self.__attr_path}.{key}")
+
+        raise AssertionError("PlaceholderModule should not be used "
+                             "when the original module can be imported")
+
+
+# create a library to hold the custom op
+vllm_lib = Library("vllm", "FRAGMENT")  # noqa
+
+
+def direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: list[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+    dispatch_key: str = "CUDA",
+):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
+    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
+    """
+    if is_in_doc_build():
+        return
+
+    if not supports_custom_op():
+        # TODO fix this..
+        return
+        assert not current_platform.is_cuda_alike(), (
+            "cuda platform needs torch>=2.4 to support custom op, "
+            "chances are you are using an old version of pytorch "
+            "or a custom build of pytorch. It is recommended to "
+            "use vLLM in a fresh new environment and let it install "
+            "the required dependencies.")
+        return
+
+    import torch.library
+    if hasattr(torch.library, "infer_schema"):
+        schema_str = torch.library.infer_schema(op_func,
+                                                mutates_args=mutates_args)
+    else:
+        # for pytorch 2.4
+        import torch._custom_op.impl
+        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
+    my_lib = target_lib or vllm_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+ """ + module_name, obj_name = qualname.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, obj_name) + + +def kill_process_tree(pid: int): + """ + Kills all descendant processes of the given pid by sending SIGKILL. + + Args: + pid (int): Process ID of the parent process + """ + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + # Get all children recursively + children = parent.children(recursive=True) + + # Send SIGKILL to all children first + for child in children: + with contextlib.suppress(ProcessLookupError): + os.kill(child.pid, signal.SIGKILL) + + # Finally kill the parent + with contextlib.suppress(ProcessLookupError): + os.kill(pid, signal.SIGKILL) + + +@dataclass +class MemorySnapshot: + """Memory snapshot.""" + torch_peak: int = 0 + cuda_memory: int = 0 + torch_memory: int = 0 + non_torch_memory: int = 0 + timestamp: float = 0.0 + auto_measure: bool = True + + def __post_init__(self): + if self.auto_measure: + self.measure() + + def measure(self): + # we measure the torch peak memory usage via allocated_bytes, + # rather than `torch.cuda.memory_reserved()` . + # After `torch.cuda.reset_peak_memory_stats()`, + # `torch.cuda.memory_reserved()` will keep growing, and only shrink + # when we call `torch.cuda.empty_cache()` or OOM happens. + self.torch_peak = torch.cuda.memory_stats().get( + "allocated_bytes.all.peak", 0) + + self.cuda_memory = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + + # torch.cuda.memory_reserved() is how many bytes + # PyTorch gets from cuda (by calling cudaMalloc, etc.) + # this is used to measure the non-torch memory usage + self.torch_memory = torch.cuda.memory_reserved() + + self.non_torch_memory = self.cuda_memory - self.torch_memory + self.timestamp = time.time() + + def __sub__(self, other: MemorySnapshot) -> MemorySnapshot: + return MemorySnapshot( + torch_peak=self.torch_peak - other.torch_peak, + cuda_memory=self.cuda_memory - other.cuda_memory, + torch_memory=self.torch_memory - other.torch_memory, + non_torch_memory=self.non_torch_memory - other.non_torch_memory, + timestamp=self.timestamp - other.timestamp, + auto_measure=False, + ) + + +@dataclass +class MemoryProfilingResult: + """Memory profiling result. All numbers are in bytes. + """ + non_kv_cache_memory: int = 0 + torch_peak_increase: int = 0 + non_torch_increase: int = 0 + weights_memory: float = 0 + before_create: MemorySnapshot = field(default_factory=MemorySnapshot) + before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + profile_time: float = 0.0 + + +@contextlib.contextmanager +def memory_profiling( + baseline_snapshot: MemorySnapshot, + weights_memory: int) -> Generator[MemoryProfilingResult, None, None]: + """Memory profiling context manager. + baseline_snapshot: the memory snapshot before the current vLLM instance. + weights_memory: memory used by PyTorch when loading the model weights. + Note that, before loading the model weights, we also initialize the device + and distributed environment, which may consume some memory. This part is not + included in the weights_memory because PyTorch does not control it. + + The memory in one GPU can be classified into 3 categories: + 1. memory used by anything other than the current vLLM instance. + 2. memory used by torch in the current vLLM instance. + 3. memory used in the current vLLM instance, but not by torch. 
+ + A quantitive example: + + Before creating the current vLLM instance: + category 1: 1 GiB + category 2: 0 GiB + category 3: 0 GiB + + After creating the current vLLM instance and loading the model, + (i.e. before profiling): + category 1: 1 GiB + category 2: 2 GiB (model weights take 2 GiB) + category 3: 0.5 GiB (memory used by NCCL) + + During profiling (peak): + category 1: 1 GiB + category 2: 4 GiB (peak activation tensors take 2 GiB) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + After profiling: + category 1: 1 GiB + category 2: 3 GiB (after garbage-collecting activation tensors) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + In this case, non-kv cache takes 5 GiB in total, including: + a. 2 GiB used by the model weights (category 2) + b. 2 GiB reserved for the peak activation tensors (category 2) + c. 1 GiB used by non-torch components (category 3) + + The memory used for loading weights (a.) is directly given from the argument `weights_memory`. + + The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). + + The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). + """ # noqa + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + result = MemoryProfilingResult() + + result.before_create = baseline_snapshot + # the part of memory used for holding the model weights + result.weights_memory = weights_memory + + result.before_profile.measure() + + yield result + + gc.collect() + torch.cuda.empty_cache() + + result.after_profile.measure() + + diff_profile = result.after_profile - result.before_profile + diff_from_create = result.after_profile - result.before_create + result.torch_peak_increase = diff_profile.torch_peak + result.non_torch_increase = diff_from_create.non_torch_memory + result.profile_time = diff_profile.timestamp + result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 +def set_ulimit(target_soft_limit=65535): + if sys.platform.startswith('win'): + logger.info("Windows detected, skipping ulimit adjustment.") + return + + import resource + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, + (target_soft_limit, current_hard)) + except ValueError as e: + logger.warning( + "Found ulimit of %s and failed to automatically increase " + "with error %s. This can cause fd limit errors like " + "`OSError: [Errno 24] Too many open files`. 
Consider " + "increasing with ulimit -n", current_soft, e) + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501 +def get_exception_traceback(): + etype, value, tb = sys.exc_info() + err_str = "".join(traceback.format_exception(etype, value, tb)) + return err_str + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 +def make_zmq_socket( + ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] + path: str, + socket_type: Any, +) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] + """Make a ZMQ socket with the proper bind/connect semantics.""" + + mem = psutil.virtual_memory() + socket = ctx.socket(socket_type) + + # Calculate buffer size based on system memory + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + # For systems with substantial memory (>32GB total, >16GB available): + # - Set a large 0.5GB buffer to improve throughput + # For systems with less memory: + # - Use system default (-1) to avoid excessive memory consumption + if total_mem > 32 and available_mem > 16: + buf_size = int(0.5 * 1024**3) # 0.5GB in bytes + else: + buf_size = -1 # Use system default buffer size + + if socket_type == zmq.constants.PULL: + socket.setsockopt(zmq.constants.RCVHWM, 0) + socket.setsockopt(zmq.constants.RCVBUF, buf_size) + socket.bind(path) + elif socket_type == zmq.constants.PUSH: + socket.setsockopt(zmq.constants.SNDHWM, 0) + socket.setsockopt(zmq.constants.SNDBUF, buf_size) + socket.connect(path) + else: + raise ValueError(f"Unknown Socket Type: {socket_type}") + + return socket + + +@contextlib.contextmanager +def zmq_socket_ctx( + path: str, + socket_type: Any, + linger: int = 0, +) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + ctx = zmq.Context() # type: ignore[attr-defined] + try: + yield make_zmq_socket(ctx, path, socket_type) + + except KeyboardInterrupt: + logger.debug("Got Keyboard Interrupt.") + + finally: + ctx.destroy(linger=linger) + + +def is_in_ray_actor(): + """Check if we are in a Ray actor.""" + + try: + import ray + return (ray.is_initialized() + and ray.get_runtime_context().get_actor_id() is not None) + except ImportError: + return False + + +def _maybe_force_spawn(): + """Check if we need to force the use of the `spawn` multiprocessing start + method. + """ + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn": + return + + reason = None + if cuda_is_initialized(): + reason = "CUDA is initialized" + elif is_in_ray_actor(): + # even if we choose to spawn, we need to pass the ray address + # to the subprocess so that it knows how to connect to the ray cluster. + # env vars are inherited by subprocesses, even if we use spawn. + import ray + os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address + reason = "In a Ray actor and can only be spawned" + + if reason is not None: + logger.warning( + "We must use the `spawn` multiprocessing start method. " + "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " + "See https://docs.vllm.ai/en/latest/getting_started/" + "troubleshooting.html#python-multiprocessing " + "for more information. Reason: %s", reason) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def get_mp_context(): + """Get a multiprocessing context with a particular method (spawn or fork). + By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to + determine the multiprocessing method (default is fork). 
However, under + certain conditions, we may enforce spawn and override the value of + VLLM_WORKER_MULTIPROC_METHOD. + """ + _maybe_force_spawn() + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + return multiprocessing.get_context(mp_method) + + +def bind_kv_cache( + ctx: dict[str, Any], + kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] +) -> None: + # Bind the kv_cache tensor to Attention modules, similar to + # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] + # Special things handled here: + # 1. Some models have non-attention layers, e.g., Jamba + # 2. Pipeline parallelism, each rank only has a subset of layers + # 3. Encoder attention has no kv cache + # 4. Encoder-decoder models, encoder-decoder attention and decoder-only + # attention of the same layer (e.g., bart's decoder.layers.1.self_attn + # and decoder.layers.1.encoder_attn) is mapped to the same kv cache + # tensor + from vllm.attention import AttentionType + from vllm.model_executor.models.utils import extract_layer_index + layer_need_kv_cache = [ + layer_name for layer_name in ctx + if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) + ] + layer_index_sorted = sorted( + set( + extract_layer_index(layer_name) + for layer_name in layer_need_kv_cache)) + for layer_name in layer_need_kv_cache: + kv_cache_idx = layer_index_sorted.index( + extract_layer_index(layer_name)) + forward_ctx = ctx[layer_name] + assert len(forward_ctx.kv_cache) == len(kv_cache) + for ve, ve_kv_cache in enumerate(kv_cache): + forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] + +def bind_kv_cache_scale( + ctx: dict[str, Any], + kv_cache_scale: list[list[torch.Tensor]], # [virtual_engine][layer_index] +) -> None: + # Bind the kv_cache tensor to Attention modules, similar to + # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] + # Special things handled here: + # 1. Some models have non-attention layers, e.g., Jamba + # 2. Pipeline parallelism, each rank only has a subset of layers + # 3. Encoder attention has no kv cache + # 4. Encoder-decoder models, encoder-decoder attention and decoder-only + # attention of the same layer (e.g., bart's decoder.layers.1.self_attn + # and decoder.layers.1.encoder_attn) is mapped to the same kv cache + # tensor + from vllm.attention import AttentionType + from vllm.model_executor.models.utils import extract_layer_index + layer_need_kv_cache_scale = [ + layer_name for layer_name in ctx + if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) + ] + layer_index_sorted = sorted( + set( + extract_layer_index(layer_name) + for layer_name in layer_need_kv_cache_scale)) + for layer_name in layer_need_kv_cache_scale: + kv_cache_idx = layer_index_sorted.index( + extract_layer_index(layer_name)) + forward_ctx = ctx[layer_name] + assert len(forward_ctx.kv_cache_scale) == len(kv_cache_scale) + for ve, ve_kv_cache_scale in enumerate(kv_cache_scale): + forward_ctx.kv_cache_scale[ve] = ve_kv_cache_scale[kv_cache_idx] + +def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], + kwargs: dict[str, Any]) -> Any: + """ + Run a method of an object with the given arguments and keyword arguments. + If the method is string, it will be converted to a method using getattr. + If the method is serialized bytes and will be deserialized using + cloudpickle. 
+ If the method is a callable, it will be called directly. + """ + if isinstance(method, bytes): + func = partial(cloudpickle.loads(method), obj) + elif isinstance(method, str): + try: + func = getattr(obj, method) + except AttributeError: + raise NotImplementedError(f"Method {method!r} is not" + " implemented.") from None + else: + func = partial(method, obj) # type: ignore + return func(*args, **kwargs) + + +def import_pynvml(): + """ + Historical comments: + + libnvml.so is the library behind nvidia-smi, and + pynvml is a Python wrapper around it. We use it to get GPU + status without initializing CUDA context in the current process. + Historically, there are two packages that provide pynvml: + - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official + wrapper. It is a dependency of vLLM, and is installed when users + install vLLM. It provides a Python module named `pynvml`. + - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper. + Prior to version 12.0, it also provides a Python module `pynvml`, + and therefore conflicts with the official one. What's worse, + the module is a Python package, and has higher priority than + the official one which is a standalone Python file. + This causes errors when both of them are installed. + Starting from version 12.0, it migrates to a new module + named `pynvml_utils` to avoid the conflict. + It is so confusing that many packages in the community use the + unofficial one by mistake, and we have to handle this case. + For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial + one, and it will cause errors, see the issue + https://github.com/vllm-project/vllm/issues/12847 for example. + After all the troubles, we decide to copy the official `pynvml` + module to our codebase, and use it directly. + """ + import vllm.third_party.pynvml as pynvml + return pynvml + + +def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: + """ + A replacement for `abc.ABC`. + When we use `abc.ABC`, subclasses will fail to instantiate + if they do not implement all abstract methods. + Here, we only require `raise NotImplementedError` in the + base class, and log a warning if the method is not implemented + in the subclass. + """ + + original_init = cls.__init__ + + def find_unimplemented_methods(self: object): + unimplemented_methods = [] + for attr_name in dir(self): + # bypass inner method + if attr_name.startswith('_'): + continue + + try: + attr = getattr(self, attr_name) + # get the func of callable method + if callable(attr): + attr_func = attr.__func__ + except AttributeError: + continue + src = inspect.getsource(attr_func) + if "NotImplementedError" in src: + unimplemented_methods.append(attr_name) + if unimplemented_methods: + method_names = ','.join(unimplemented_methods) + msg = (f"Methods {method_names} not implemented in {self}") + logger.warning(msg) + + @wraps(original_init) + def wrapped_init(self, *args, **kwargs) -> None: + original_init(self, *args, **kwargs) + find_unimplemented_methods(self) + + type.__setattr__(cls, '__init__', wrapped_init) + return cls + + +class LazyLoader(types.ModuleType): + """ + LazyLoader module borrowed from Tensorflow + https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py + with a addition of "module caching". + + Lazily import a module, mainly to avoid pulling in large dependencies. 
+ Modules such as `xgrammar` might do additional side effects, so we + only want to use this when it is needed, delaying all eager effects + """ + + def __init__( + self, + local_name: str, + parent_module_globals: dict[str, Any], + name: str, + ): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + self._module: types.ModuleType | None = None + + super().__init__(str(name)) + + def _load(self) -> types.ModuleType: + # Import the target module and insert it into the parent's namespace + try: + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # The additional add to sys.modules + # ensures library is actually loaded. + sys.modules[self._local_name] = module + except ModuleNotFoundError as err: + raise err from None + + # Update this object's dict so that if someone keeps a + # reference to the LazyLoader, lookups are efficient + # (__getattr__ is only called on lookups that fail). + self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item: Any) -> Any: + if self._module is None: + self._module = self._load() + return getattr(self._module, item) + + def __dir__(self) -> list[str]: + if self._module is None: + self._module = self._load() + return dir(self._module) + + +def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: + """ + Helper function to swap values for two keys + """ + v1 = obj.get(key1) + v2 = obj.get(key2) + if v1 is not None: + obj[key2] = v1 + else: + obj.pop(key2, None) + if v2 is not None: + obj[key1] = v2 + else: + obj.pop(key1, None) + + +@contextlib.contextmanager +def cprofile_context(save_file: Optional[str] = None): + """Run a cprofile + + Args: + save_file: path to save the profile result. "1" or + None will result in printing to stdout. + """ + import cProfile + + prof = cProfile.Profile() + prof.enable() + + try: + yield + finally: + prof.disable() + if save_file and save_file != "1": + prof.dump_stats(save_file) + else: + prof.print_stats(sort="cumtime") + + +def cprofile(save_file: Optional[str] = None, enabled: bool = True): + """Decorator to profile a Python method using cProfile. + + Args: + save_file: Path to save the profile result. + If "1", None, or "", results will be printed to stdout. + enabled: Set to false to turn this into a no-op + """ + + def decorator(func: Callable): + + @wraps(func) + def wrapper(*args, **kwargs): + if not enabled: + # If profiling is disabled, just call the function directly. + return func(*args, **kwargs) + + with cprofile_context(save_file): + return func(*args, **kwargs) + + return wrapper + + return decorator + + +# Only relevant for models using ALiBi (e.g, MPT) +def check_use_alibi(model_config: ModelConfig) -> bool: + return (getattr(model_config.hf_text_config, "alibi", False) # Falcon + or ("BloomForCausalLM" in getattr(model_config.hf_config, + "architectures", [])) # Bloom + or getattr(model_config.hf_text_config, "position_encoding_type", + "") == "alibi" # codellm_1b_alibi + or + (hasattr(model_config.hf_text_config, "attn_config") # MPT + and model_config.hf_text_config.attn_config.get("alibi", False))) + + +def sha256(input) -> int: + """Hash any picklable Python object using SHA-256. + + The input is serialized using pickle before hashing, which allows + arbitrary Python objects to be used. Note that this function does + not use a hash seed—if you need one, prepend it explicitly to the input. + + Args: + input: Any picklable Python object. 
+ + Returns: + An integer representing the SHA-256 hash of the serialized input. + """ + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + return int.from_bytes(hashlib.sha256(input_bytes).digest(), + byteorder="big") diff --git a/vllm/v1/__init__.py b/vllm/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/attention/__init__.py b/vllm/v1/attention/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/attention/backends/__init__.py b/vllm/v1/attention/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py new file mode 100644 index 0000000..b812ae1 --- /dev/null +++ b/vllm/v1/attention/backends/flash_attn.py @@ -0,0 +1,739 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +# from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cdiv +from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +# if current_platform.is_cuda(): +# from vllm.vllm_flash_attn import flash_attn_varlen_func +from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, merge_attn_states + +logger = init_logger(__name__) + + +class FlashAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return use_cascade_attention(*args, **kwargs) + + +@dataclass +class FlashAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. 
+ max_query_len: int + query_start_loc: torch.Tensor + key_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + cu_prefix_kv_lens: Optional[torch.Tensor] + cu_suffix_kv_lens: Optional[torch.Tensor] + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + # for local attention + @dataclass + class LocalAttentionMetadata: + local_query_start_loc: torch.Tensor + local_k_start_loc: torch.Tensor + local_seqused_k: torch.Tensor + local_block_table: torch.Tensor + local_max_query_len: int + local_max_seq_len: int + + local_attn_metadata: Optional[LocalAttentionMetadata] = None + + +# +# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into +# local attention blocks, where each block is passed to the attention kernel +# as an independent local ("virtual") batch item. +# +# For example, if are performing a chunked prefill a batch of 3 sequences: +# q_seqlens = [4, 10, 5] +# kv_seqlens = [6, 17, 9] +# Then normally for regular attention we would compute with an attention mask +# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 1 1 1 1 +# 3 | 1 1 1 1 1 1 +# +# for local attention (with attn_chunk_size = 4) we would compute with an +# attention mask like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 +# 3 | 1 1 +# +# We can simulate this mask using standard flash-attention by breaking the +# sequences into local ("virtual") batches, where each local batch item is a +# local attention block, so in this case batch idx 0 would be broken up into: +# +# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) +# k_toks > 0 1 2 3 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) +# k_toks > 4 5 +# q_toks v _____________ +# 2 | 1 +# 3 | 1 1 +# +# e.g. if we have: +# attn_chunk_size = 4 +# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) +# Then this function would return: +# __b0__ ______b1______ __b2__ < orig batch indices +# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] +# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24] +# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] +# block_table_local : shape[local_virtual_batches, pages_per_local_batch] +def make_local_attention_virtual_batches( + attn_chunk_size: int, + query_start_loc_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table: torch.tensor, + page_size: int = 0, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]: + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] + actual_batch_size = seq_lens_np.shape[0] + + # Handle if we are starting in the middle of a local attention block, + # we assume q_seqlens > 0 (for all elements), for each batch idx we compute + # the number of tokens that are not in the first local attention block and + # then we can simply use a cdiv for the rest. 
+ # For example if we have: + # attn_chunk_size = 4 + # q_seqlens = [4, 10, 5] + # k_seqlens = [6, 17, 9] + # Then we would get: + # new_tokens_in_first_block = [2, 1, 4] + # local_blocks = [2, 4, 2] + q_tokens_in_first_block = np.minimum( + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), + q_seqlens).astype(np.int32) + tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, + attn_chunk_size) + + # Once we know the number of local blocks we can compute the request spans + # for each batch idx, we can figure out the number of "virtual" requests we + # have to make, + # For the above example we would get: + # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] + # + # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) + # (TODO: max a utility to share this code with _prepare_inputs) + # arange step 1. [2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = \ + np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), + attn_chunk_size)[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ + .astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], + attn_chunk_size, + dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_k_local = np.pad(np.cumsum(seqlens_k_local), (1, 0))\ + .astype(np.int32) + + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ + (rarange * attn_chunk_size + \ + np.repeat(tokens_in_last_block, local_blocks)) + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // page_size + assert attn_chunk_size % page_size == 0, \ + f"attn_chunk_size {attn_chunk_size} is not " \ + f"divisible by page_size {page_size}" + pages_per_local_batch = attn_chunk_size // page_size + + # Create a block_table for the local attention blocks + # For out example if we have a block-table like (assuming page_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < 
local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices= np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch)) \ + + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten() + batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch) + block_table_local = block_table[batch_indices, block_indices]\ + .view(virtual_batches, -1) + + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, cu_seqlens_k_local, \ + block_table_local + + +class FlashAttentionMetadataBuilder: + + def __init__(self, runner: "GPUModelRunner"): + self.runner = runner + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + max_seq_len = self.runner.seq_lens_np[:num_reqs].max() + query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] + query_start_loc = query_start_loc_cpu.to(self.runner.device, + non_blocking=True) + seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] + seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) + key_start_loc = torch.zeros([seq_lens_cpu.shape[0]+1]) + key_start_loc[1:] = seq_lens_cpu + key_start_loc = key_start_loc.cumsum(dim=0).to(seq_lens.dtype) + key_start_loc = key_start_loc.to(self.runner.device, non_blocking=True) + + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True).long() + + # for local attention + local_attn_metadata = None + if self.runner.attention_chunk_size is not None: + seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, virt_k_cu_seqlens_np, \ + virt_block_table = make_local_attention_virtual_batches( + self.runner.attention_chunk_size, + self.runner.query_start_loc_np[:num_reqs + 1], + self.runner.seq_lens_np[:num_reqs], + block_table, + self.runner.block_size, + ) + local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=torch.from_numpy( + virt_q_cu_seqlens_np).to(self.runner.device, + non_blocking=True), + local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True), + local_block_table=virt_block_table, + local_max_query_len=seqlens_q_local_np.max(), + local_max_seq_len=virt_k_seqlens_np.max(), + local_k_start_loc=torch.from_numpy( + virt_k_cu_seqlens_np).to(self.runner.device, + non_blocking=True), + ) + + use_cascade = common_prefix_len > 0 + if use_cascade: + cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=self.runner.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.runner.device) + cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], + dtype=torch.int32, + device=self.runner.device) + suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - + common_prefix_len) + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( + self.runner.device) + cu_suffix_kv_lens = 
suffix_kv_lens.new_zeros([suffix_kv_lens.shape[0]+1]) + cu_suffix_kv_lens[1:] = suffix_kv_lens + cu_suffix_kv_lens = cu_suffix_kv_lens.cumsum(dim=0).int() + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + cu_prefix_kv_lens = None + cu_suffix_kv_lens = None + + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + key_start_loc=key_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, + cu_prefix_kv_lens=cu_prefix_kv_lens, + cu_suffix_kv_lens=cu_suffix_kv_lens, + ) + + return attn_metadata + + +class FlashAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + use_irope: bool = False, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}. " + "Set VLLM_USE_V1=0 to use another attention backend.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + self.use_irope = use_irope + self.vllm_flash_attn_version = get_flash_attn_version() + if is_quantized_kv_cache(self.kv_cache_dtype) \ + and not flash_attn_supports_fp8(): + raise NotImplementedError( + "FlashAttention does not support fp8 kv-cache on this device.") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + kv_cache_scale: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [num_tokens, num_heads * head_size] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + + return output.view(-1, self.num_heads * self.head_size) + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + # Reshape the input keys and values and store them in the cache. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. + key_cache, value_cache = kv_cache.unbind(0) + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + # Compute attention and update output up to `num_actual_tokens`. + use_local_attn = \ + (self.use_irope and attn_metadata.local_attn_metadata is not None) + + if not attn_metadata.use_cascade or use_local_attn: + if use_local_attn: + assert attn_metadata.local_attn_metadata is not None + local_metadata = attn_metadata.local_attn_metadata + cu_seqlens_q = local_metadata.local_query_start_loc + seqused_k = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len + block_table = local_metadata.local_block_table + cu_seqlens_k = local_metadata.local_k_start_loc + else: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + cu_seqlens_k = attn_metadata.key_start_loc + + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + flash_attn_varlen_func( # noqa + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + block_table=block_table, + softcap=self.logits_soft_cap, + sqrt_alibi=False, + out=output[:num_actual_tokens], + ) + return output.view(-1, self.num_heads * self.head_size) + + assert not use_local_attn, ( + "Cascade attention does not support local attention.") + # Cascade attention (rare case). 
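+        # The call below attends to the shared prefix once for all queries
+        # (causal=False over the common-prefix KV blocks) and to each
+        # request's remaining suffix KV separately (causal=True), then
+        # combines the two partial results from their softmax log-sum-exp
+        # statistics inside cascade_attention().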
+ cascade_attention( + output[:num_actual_tokens], + query[:num_actual_tokens], + key_cache, + value_cache, + cu_query_lens=attn_metadata.query_start_loc, + max_query_len=attn_metadata.max_query_len, + cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, + cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens, + cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens, + max_kv_len=attn_metadata.max_seq_len, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, + block_table=attn_metadata.block_table, + common_prefix_len=attn_metadata.common_prefix_len, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale, + k_descale=layer._k_scale, + v_descale=layer._v_scale, + ) + return output.view(-1, self.num_heads * self.head_size) + + +def use_cascade_attention( + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + num_sms: int, +) -> bool: + """Decide whether to use cascade attention. + + This function 1) checks whether cascade attention is supported with the + given configuration, and 2) heuristically decides whether using cascade + attention can improve performance. + """ + # Too short common prefix. Probably not worth using cascade attention. + # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. + # NOTE(woosuk): This is the common case. We should return False as soon as + # possible to avoid any unnecessary computation. + if common_prefix_len < 256: + return False + # Cascade attention is currently not supported with these variants. + if use_alibi or use_sliding_window: + return False + # Too few queries. Probably not worth using cascade attention. + # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. + num_reqs = len(query_lens) + if num_reqs < 8: + return False + + # Heuristics to decide whether using cascade attention is beneficial. + # 1. When FlashDecoding is not used for normal attention, cascade attention + # is likely to be faster since it saves memory bandwidth. + num_queries_per_kv = num_query_heads // num_kv_heads + # The criteria for using FlashDecoding can be found in the following link: + # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 + use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window + and not use_alibi and np.all(query_lens == 1)) + if not use_flash_decoding: + # Use cascade attention. + return True + else: + # flash_decoding not supported now! + return False + + # 2. When FlashDecoding is used for normal attention, it is not clear + # whether cascade attention is beneficial, because FlashDecoding can + # launch more CTAs than cascade attention. + # We use a simple performance model to compare the two methods. + # NOTE(woosuk): The performance model is very rough and may not be + # accurate. + num_tokens = num_reqs + # NOTE(woosuk): These are default tile sizes. flash-attn might use + # different tile sizes (e.g., 64 or 256) depending on the configuration. 
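+    # Hypothetical worked example of this cost model (note that the early
+    # returns above mean this code path is currently unreachable): with 64
+    # decode requests, 32 query heads, 8 KV heads (num_queries_per_kv = 4),
+    # common_prefix_len = 1024 and num_sms = 108:
+    #   num_prefix_tiles    = cdiv(1024, 128)           = 8
+    #   cascade_time        = cdiv(32 * 1, 108) * 8     = 8
+    #   flash_decoding_time = cdiv(64 * 8 * 1 * 8, 108) = 38
+    # so cascade attention would be preferred (8 < 38).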
+ q_tile_size = 128 + kv_tile_size = 128 + num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) + + cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) + cascade_waves = cdiv(cascade_ctas, num_sms) + cascade_time = cascade_waves * num_prefix_tiles + + flash_decoding_ctas = (num_reqs * num_kv_heads * + cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas *= num_prefix_tiles + flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) + + # Use cascade attention if it is faster than FlashDecoding. + return cascade_time < flash_decoding_time + + +def cascade_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cu_query_lens: torch.Tensor, + max_query_len: int, + cu_prefix_query_lens: torch.Tensor, + cu_prefix_kv_lens: torch.Tensor, + cu_suffix_kv_lens: torch.Tensor, + max_kv_len: int, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + sliding_window: tuple[int, int], + logits_soft_cap: float, + block_table: torch.Tensor, + common_prefix_len: int, + fa_version: int, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") + # TODO: Support sliding window. + assert sliding_window == (-1, -1), ( + "Cascade attention does not support sliding window.") + + num_tokens = query.shape[0] + block_size = key_cache.shape[-2] + assert common_prefix_len % block_size == 0 + num_common_kv_blocks = common_prefix_len // block_size + assert num_common_kv_blocks > 0 + assert q_descale is None or q_descale==1, f"q_descale is not None, q_descale: {q_descale}" + assert k_descale is None or k_descale==1, f"k_descale is not None, k_descale: {k_descale}" + assert v_descale is None or v_descale==1, f"v_descale is not None, v_descale: {v_descale}" + + # Process shared prefix. + prefix_output, prefix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_prefix_query_lens, + cu_seqlens_k=cu_prefix_kv_lens, + max_seqlen_q=num_tokens, + max_seqlen_k=common_prefix_len, + softmax_scale=softmax_scale, + causal=False, + window_size=sliding_window, + block_table=block_table[:1], + softcap=logits_soft_cap, + return_softmax_lse=True, + ) + + # Process suffix per query. + suffix_output, suffix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_suffix_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len - common_prefix_len, + softmax_scale=softmax_scale, + causal=True, + window_size=sliding_window, + block_table=block_table[:, num_common_kv_blocks:], + softcap=logits_soft_cap, + return_softmax_lse=True, + ) + + def ref_merge_state(out_1, lse_1, out_2, lse_2): + num_heads, seq_len = lse_1.shape + + lse_2 = lse_2.transpose(0,1).view(seq_len, num_heads, 1) + lse_1 = lse_1.transpose(0,1).view(seq_len, num_heads, 1) + + s_max = torch.maximum(lse_1, lse_2) + + d = torch.exp2(lse_1-s_max) + torch.exp2(lse_2-s_max) + v_merged = out_1 * torch.exp2(lse_1-s_max) + out_2 * torch.exp2(lse_2-s_max) + v_merged = v_merged / d + return v_merged, (torch.log2(d) + s_max).view(seq_len, num_heads) + + # Merge prefix and suffix outputs, and store the result in output. 
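+    # NOTE: ref_merge_state above is a pure-PyTorch reference for this
+    # log-sum-exp based merge and is not called here. The ixformer
+    # merge_attn_states below takes the destination tensor as its last
+    # argument, whereas the original vLLM call (kept commented out)
+    # passed it first.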
+ # merge_attn_states(output, prefix_output, prefix_lse, suffix_output, + # suffix_lse) + merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output) + diff --git a/vllm/v1/attention/backends/mla/__init__.py b/vllm/v1/attention/backends/mla/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py new file mode 100644 index 0000000..09275bd --- /dev/null +++ b/vllm/v1/attention/backends/mla/common.py @@ -0,0 +1,1194 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. + +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N, P] +W_KR project h_t to k_pe shape [H, R] +W_UV project kv_c to v shape [Lkv, N, V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. 
"_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) +v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] concatnated per head + `q_b_proj` is [W_UQ; W_QR] concatnated per head + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Runtime +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(-1, N, P) +ql_nope = einsum("snh,lnh->snl", q, W_UK) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([ql_nope, q_pe], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), + kv_c +) + +o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) +return o.view(-1, N * V) @ self.num_heads @ W_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. + +However, the compute-friendly approach can potentially run out of memory if Skv +is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` + +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +fixed workspace size. 
+ +The chunked prefill approach is as follows: + +MCC Max chunk of context to process per iter, computed dynamically, + used to bound the memory usage + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) +new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) + +// MHA between queries and new KV +// with QK headdim = P + R +// V headdim = V +// curr_o shape [Sq, N, V] +// curr_lse shape [N, Sq], this is just order FA returns +curr_o, curr_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + new_v, + casual=True, + return_softmax_lse=True +) + +// Compute attention with the already existing context +for chunk_idx in range(cdiv(C, MCC)): + chunk_start = chunk_idx * MCC + chunk_end = min(chunk_start + MCC, C) + Sc = chunk_end - chunk_start + cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] + cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] + cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) + cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) + + chunk_o, chunk_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], + dim=-1), + cache_v_chunk, + casual=False, + return_softmax_lse=True + ) + + curr_o, curr_lse = merge_attn_states( + suffix_output=curr_o, + suffix_lse=curr_lse, + prefix_output=chunk_o, + prefix_lse=chunk_lse, + ) + +return curr_o @ W_O +""" + +import functools +from abc import abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar + +import torch + +from vllm import _custom_ops as ops +import vllm.envs as envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform +from vllm.utils import cdiv, round_down +from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version + +# try: +# from vllm.vllm_flash_attn import flash_attn_varlen_func +# except ImportError: +# # For rocm use upstream flash attention +# from flash_attn import flash_attn_varlen_func +from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, merge_attn_states + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_quantize, scaled_dequantize) +import ixformer.inference.functions as ixf_ops + +logger = init_logger(__name__) + + +class MLACommonBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "TRITON_MLA_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def 
get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [576] + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + +@dataclass +class MLACommonPrefillMetadata: + """ Prefill Specific Metadata """ + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: Optional[ChunkedContextMetadata] = None + + +@dataclass +class MLACommonDecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_decode_seq_len: int + use_cuda_graph: bool + + +D = TypeVar("D", bound=MLACommonDecodeMetadata) + + +@dataclass +class MLACommonMetadata(Generic[D]): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. 
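+
+    # NOTE: `MLACommonMetadataBuilder.reorder_batch` moves the "decode"
+    # requests to the front of the batch, so requests [0, num_decodes) are
+    # decodes, requests [num_decodes, num_reqs) are prefills, and the first
+    # `num_decode_tokens` entries of any flat per-token tensor belong to the
+    # decode requests. The forward pass relies on this layout to split its
+    # inputs into a decode path and a prefill path.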
+ + # The dimension of the attention heads + head_dim: Optional[int] = None + + decode: Optional[D] = None + prefill: Optional[MLACommonPrefillMetadata] = None + + def __post_init__(self): + supported_head_sizes = MLACommonBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + +M = TypeVar("M", bound=MLACommonMetadata) + + +class MLACommonMetadataBuilder(Generic[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, + runner: "GPUModelRunner", + metadata_cls: Optional[type[M]] = None): + self.metadata_cls = metadata_cls \ + if metadata_cls is not None else MLACommonMetadata + self.runner = runner + scheduler_config = runner.scheduler_config + model_config = runner.model_config + cache_config = runner.cache_config + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + model_config.get_head_size()), + dtype=model_config.dtype, + device=runner.device, + ) + self.page_size = self.runner.block_size + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the TritonMLA._forward_decode only supports + # num_tokens = 1 + if num_tokens == 1: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. 
past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + return modified_batch + + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, seq_lens: torch.Tensor, + max_decode_seq_len: int, use_cuda_graph: bool): + return MLACommonDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + max_decode_seq_len=max_decode_seq_len, + use_cuda_graph=use_cuda_graph, + ) + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int) -> M: + assert self._num_decodes + self._num_prefills == num_reqs + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. + device = self.runner.device + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + device, non_blocking=True) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + device, non_blocking=True).long() + input_positions = self.runner.positions_cpu[:num_actual_tokens].to( + device, non_blocking=True).long() + + seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] + seq_lens = seq_lens_cpu.to(device, non_blocking=True) + max_query_len = seq_lens_cpu.max().item() + + prefill_metadata = None + if self._num_prefills > 0: + reqs_start = self._num_decodes # prefill_start + tokens_start = self._num_decode_tokens + + context_lens_cpu = self.runner.input_batch.\ + num_computed_tokens_cpu_tensor[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + + chunked_context_metadata = None + if self.chunked_prefill_enabled and self._num_prefills > 0 \ + and max_context_len_cpu > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) + + assert max_context_chunk > 0 + 
num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. + chunk_starts = \ + torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, self._num_prefills) \ + * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + + cu_seq_lens_cpu = torch.zeros(num_chunks, + self._num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + + chunked_context_metadata = \ + MLACommonPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + workspace=self.chunked_prefill_workspace, + ) + + assert max(chunked_context_metadata.max_seq_lens) <= \ + self.chunked_prefill_workspace_size + + prefill_metadata = MLACommonPrefillMetadata( + input_positions=input_positions[tokens_start:], + block_table=block_table[reqs_start:, ...], + query_start_loc=query_start_loc[reqs_start:] - + query_start_loc[reqs_start], + max_query_len=max_query_len, + chunked_context=chunked_context_metadata, + ) + + decode_metadata = None + if self._num_decodes > 0: + decode_metadata = self._build_decode( + input_positions=input_positions[:self._num_decode_tokens], + block_table=block_table[:self._num_decodes, ...], + seq_lens=seq_lens[:self._num_decodes], + max_decode_seq_len=max_query_len, + use_cuda_graph=False, + ) + + return self.metadata_cls( + num_actual_tokens=num_actual_tokens, + query_start_loc=query_start_loc, + slot_mapping=slot_mapping, + head_dim=self.runner.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=self._num_decodes, + num_decode_tokens=self._num_decode_tokens, + num_prefills=self._num_prefills, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + +class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = 
qk_head_dim + self.v_head_dim = v_head_dim + + # Hack for V1 for now to avoid torch library overhead (since we are + # already inside an attention custom op), pull out the forward + # method from the rotary embedding and call it directly + # TODO(lucas): we should probably find a cleaner way to do this + self.rotary_emb = rotary_emb + # self.rotary_emb = rotary_emb.forward_native + # if current_platform.is_cuda(): + # self.rotary_emb = rotary_emb.forward_cuda + + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + self.vllm_flash_attn_version = get_flash_attn_version() + + # Handle the differences between the flash_attn_varlen from flash_attn + # and the one from vllm_flash_attn. The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + def _v_up_proj_and_o_proj(self, x): + # # Convert from (B, N, L) to (N, B, L) + # x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + # x = torch.bmm(x, self.W_UV) + # # Convert from (N, B, V) to (B, N * V) + # x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + # return self.o_proj(x)[0] + x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) + return self.o_proj(x.reshape(-1, self.num_heads * self.v_head_dim))[0] + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # # Convert from (B, N, P) to (N, B, P) + # q_nope = q_nope.transpose(0, 1) + # # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + # ql_nope = torch.bmm(q_nope, self.W_UK_T) + # # Convert from (N, B, L) to (B, N, L) + # return ql_nope.transpose(0, 1), q_pe + return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank), q_pe + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # we already prepare a dequant weight for GGUF, skip it. 
+ if layer.quant_method.__class__.__name__ == "GGUFLinearMethod": + weight = layer.weight.clone() + del layer.weight + return weight + if layer.quant_method.__class__.__name__ == "AWQMarlinLinearMethod": + from ixformer.inference.functions import ref_wui4a16 + return ref_wui4a16(None, layer.qweight, layer.scales, layer.qzeros, None, layer.quant_method.quant_config.group_size, only_return_weight=True) + # for W8A8, we directly dequantize it here to avoiding quantization errors + if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" and not layer.scheme.is_static_input_scheme: + quant_weight = layer.weight.T # output, input + scales = layer.weight_scale + return scaled_dequantize(quant_weight, scales, (1, -1), act_dtype) + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + back_to_vllm = False + weight_dtype = get_layer_weight(self.kv_b_proj).dtype + assert get_layer_weight(self.q_proj).dtype == weight_dtype + + # when use customize forward, we do not care the specific data types and shape, just reset the + # _v_up_proj_and_o_proj and _q_proj_and_k_up_proj funs to get the correct results. + if envs.VLLM_MLA_CUSTOMIZE: + layer = self.kv_b_proj + quant_method = layer.quant_method.__class__.__name__ + if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8": + # must be as same as kv_b_proj + assert hasattr(self.q_proj, "scheme") and self.q_proj.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" + assert hasattr(self.o_proj, "scheme") and self.o_proj.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" + + kv_b_proj_weight = self.kv_b_proj.weight # [i,o] + kv_b_proj_weight_scale = self.kv_b_proj.weight_scale # [o, 1] + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_b_proj_weight_scale = kv_b_proj_weight_scale.view( + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + 1, + ) + W_UK_S, W_UV_S = kv_b_proj_weight_scale.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=1) + + # self.W_UK = W_UK # [kv_lora_rank, n, qk_nope] + self.W_UK = ixf_ops.marlin_w8_weight_repack( + weights=W_UK.permute(1, 2, 0).contiguous(), + weight_format="int8", + reformat="k16n16_grouped_n", + ) + self.W_UK_S = W_UK_S.contiguous() # [n, qk_nope, 1] + # self.W_UV = W_UV # [kv_lora_rank, n, v_head_dim] + self.W_UV = ixf_ops.marlin_w8_weight_repack( + weights=W_UV.permute(1, 0, 2).contiguous(), + weight_format="int8", + reformat="k16n16", + ) + self.W_UV_S = W_UV_S.contiguous() # [n, v_head_dim, 1] + + def _v_up_proj_and_o_proj_w8a8(self, x): + # x: b, n, kv_lora + # W_UV = (self.W_UV * self.W_UV_S.view(1, self.num_heads, self.v_head_dim)).to(x.dtype) + # x = torch.einsum("bnl,lnv->bnv", x, W_UV) + res = ixf_ops.marlin_w8a16( + x, + self.W_UV, + self.W_UV_S, + group_size=-1, + format="k16n16", + batch_first=False, + ) + return self.o_proj(res.reshape(-1, self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj_w8a8(self, x): + # x: b, q_lora + x = self.q_proj(x)[0].view(-1, self.num_heads, self.qk_head_dim) + x_nope, x_pe = x.split([self.qk_nope_head_dim, 
self.qk_rope_head_dim], dim=-1) + # W_UK = (self.W_UK * self.W_UK_S.view(1, self.num_heads, self.qk_nope_head_dim)).to(x.dtype) + # return torch.einsum("bnp,lnp->bnl", x_nope, W_UK).view(-1, self.num_heads, self.kv_lora_rank), x_pe + res = ixf_ops.marlin_w8a16( + x_nope, + self.W_UK, + self.W_UK_S, + group_size=-1, + format="k16n16_grouped_n", + batch_first=False, + ) + return res, x_pe + + self._v_up_proj_and_o_proj = functools.partial(_v_up_proj_and_o_proj_w8a8, self) + self._q_proj_and_k_up_proj = functools.partial(_q_proj_and_k_up_proj_w8a8, self) + elif quant_method == "AWQMarlinLinearMethod": + # ===== W_UK & W_UV use W4A16 ===== + def split_4bit_matrix(qweight, scales, qzeros, num_heads, head_dim_1, head_dim_2): + assert head_dim_1 % 8 == 0 + assert head_dim_2 % 8 == 0 + qweight = qweight.view(-1, num_heads, (head_dim_1 + head_dim_2)//8) + w1, w2 = qweight.split([head_dim_1//8, head_dim_2//8], dim=-1) + scales = scales.view(-1, num_heads, head_dim_1 + head_dim_2) + s1, s2 = scales.split([head_dim_1, head_dim_2], dim=-1) + qzeros = qzeros.view(-1, num_heads, (head_dim_1 + head_dim_2)//8) + z1, z2 = qzeros.split([head_dim_1//8, head_dim_2//8], dim=-1) + return (w1.contiguous().view(w1.shape[0], -1), + s1.contiguous().view(s1.shape[0], -1), + z1.contiguous().view(z1.shape[0], -1), + w2.contiguous().view(w2.shape[0], -1), + s2.contiguous().view(s2.shape[0], -1), + z2.contiguous().view(z2.shape[0], -1)) + # split W_UK and W_UV + ( + W_UK, W_UK_S, W_UK_Z, + W_UV, W_UV_S, W_UV_Z + ) = split_4bit_matrix( + self.kv_b_proj.qweight, self.kv_b_proj.scales, self.kv_b_proj.qzeros, + self.num_heads, self.qk_nope_head_dim, self.v_head_dim + ) + # repack W_UK + W_UK = W_UK.reshape(self.kv_lora_rank, self.num_heads, -1).permute(1, 2, 0).contiguous() + W_UK_S = W_UK_S.reshape(W_UK_S.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + W_UK_Z = W_UK_Z.reshape(W_UK_Z.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + ( + self.W_UK, + self.W_UK_S, + self.W_UK_Z + ) = ixf_ops.marlin_w4_weight_repack( + weights = W_UK, + scales = W_UK_S, + zeros = W_UK_Z, + weight_format = "gptq_grouped_n", + reformat = "k16n32_grouped_n", + ) + # repack W_UV + W_UV = W_UV.reshape(self.kv_lora_rank, self.num_heads, -1).permute(1, 0, 2).contiguous() + W_UV_S = W_UV_S.reshape(W_UV_S.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + W_UV_Z = W_UV_Z.reshape(W_UV_Z.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() + ( + self.W_UV, + self.W_UV_S, + self.W_UV_Z + ) = ixf_ops.marlin_w4_weight_repack( + weights = W_UV, + scales = W_UV_S, + zeros = W_UV_Z, + weight_format = "awq", + reformat = "k16n32", + ) + self.w4_group_size = self.q_proj.quant_method.quant_config.group_size + + # ===== W_UK & W_UV use W16A16 ===== + # kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + # kv_b_proj_weight = kv_b_proj_weight.view( + # self.kv_lora_rank, + # self.num_heads, + # self.qk_nope_head_dim + self.v_head_dim, + # ) + # self.W_UK_fp16, self.W_UV_fp16 = kv_b_proj_weight.split( + # [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + def _v_up_proj_and_o_proj_w4a16(self, x): + # W16A16 + # res = torch.einsum("bnl,lnv->bnv", x, self.W_UV_fp16) + + # W4A16 + res = ixf_ops.marlin_w4a16( + x, self.W_UV, self.W_UV_S, self.W_UV_Z, + group_size=self.w4_group_size, + format="k16n32", + batch_first=False) + + return self.o_proj(res.view(-1, self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj_w4a16(self, x): + x = self.q_proj(x)[0].view(-1, self.num_heads, self.qk_head_dim) + 
x_nope, x_pe = x.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # W16A16 + # return torch.einsum("bnp,lnp->bnl", x_nope, self.W_UK_fp16).view(-1, self.num_heads, self.kv_lora_rank), x_pe + + # W4A16 + res = ixf_ops.marlin_w4a16( + x_nope, self.W_UK, self.W_UK_S, self.W_UK_Z, + group_size=self.w4_group_size, + format="k16n32_grouped_n", + batch_first=False, + ) + return res, x_pe + + self._v_up_proj_and_o_proj = functools.partial(_v_up_proj_and_o_proj_w4a16, self) + self._q_proj_and_k_up_proj = functools.partial(_q_proj_and_k_up_proj_w4a16, self) + else: + print(f"Custom MLA for quant method: {layer.quant_method.__class__.__name__} is not supported, will use the vllm official impl.") + back_to_vllm = True + else: + back_to_vllm = True + + if back_to_vllm: + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # # Convert from (L, N, V) to (N, L, V) + # self.W_UV = W_UV.transpose(0, 1) + # # Convert from (L, N, P) to (N, P, L) + # self.W_UK_T = W_UK.permute(1, 2, 0) + + self.W_UV = W_UV + self.W_UK = W_UK + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache_scale: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + if envs.VLLM_USE_INT8_MLA: + ops.gather_cache_int8( + src_cache=kv_c_and_k_pe_cache, + src_cache_scale=kv_c_and_k_pe_cache_scale, + kv_lora_rank = self.kv_lora_rank, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + else: + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank].contiguous() + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) + + # For MLA the v head dim is smaller than qk head dim so we 
pad + # out v with 0s to match the qk head dim + # v_padded = torch.nn.functional.pad(v, + # [0, q.shape[-1] - v.shape[-1]], + # value=0) + + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + out=attn_output, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache_scale: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert attn_metadata.prefill is not None + + has_context = attn_metadata.prefill.chunked_context is not None + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v_nope = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + v = v_nope + k[...,:self.qk_nope_head_dim] = k_nope + attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) + + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=attn_metadata.prefill.query_start_loc, + cu_seqlens_k=attn_metadata.prefill.query_start_loc, + max_seqlen_q=attn_metadata.prefill.max_query_len, + max_seqlen_k=attn_metadata.prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + out=attn_output, + ) + + if has_context: + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, kv_c_and_k_pe_cache_scale,attn_metadata) + + output = torch.empty_like(suffix_output) + output = merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # slice by `:v.shape[-1]` in order to remove v headdim padding + output = output.view(-1, self.num_heads * self.v_head_dim) + + return self.o_proj(output)[0] + + @abstractmethod + def _forward_decode( + self, + ql_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: M, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + kv_cache_scale:torch.Tensor, + attn_metadata: M, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. 
+ output = torch.empty(output.shape[0], self.o_proj.output_size,device=hidden_states_or_q_c.device, + dtype=hidden_states_or_q_c.dtype) + return output + # num_actual_toks = attn_metadata.num_actual_tokens + + # # Inputs and outputs may be padded for CUDA graphs + # output_padded = output + # output = output[:num_actual_toks, ...] + # hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + # k_c_normed = k_c_normed[:num_actual_toks, ...] + # k_pe = k_pe[:num_actual_toks, ...] + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + decode_k_pe = k_pe[:num_decode_tokens] + + prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + if has_decode: + assert attn_metadata.decode is not None + decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_pe, decode_k_pe = self.rotary_emb(attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) + + if has_prefill: + assert attn_metadata.prefill is not None + prefill_q = self.q_proj(prefill_hs_or_q_c)[0].view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_k = torch.empty_like(prefill_q) + ixf_ops.mla_rope(attn_metadata.prefill.input_positions, prefill_q_pe, prefill_k_pe.squeeze(1), prefill_k[...,self.qk_nope_head_dim:], self.rotary_emb.cos_sin_cache) + # write the latent and rope to kv cache + write_kv_cache = (None, None) + if kv_cache.numel() > 0: + # TODO.. 
+ if has_decode: + if envs.VLLM_USE_INT8_MLA: + write_kv_cache = (None, None) + ops.concat_and_cache_mla_int8( + k_c_normed[:num_decode_tokens], + decode_k_pe, + kv_cache, + kv_cache_scale, + attn_metadata.slot_mapping.flatten()[:num_decode_tokens], + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + else: + write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe) + # ops.concat_and_cache_mla( + # k_c_normed[:num_decode_tokens], + # decode_k_pe, + # kv_cache, + # attn_metadata.slot_mapping.flatten()[:num_decode_tokens], + # kv_cache_dtype=self.kv_cache_dtype, + # scale=layer._k_scale, + # ) + if has_prefill: + if envs.VLLM_USE_INT8_MLA: + ops.concat_and_cache_mla_int8( + prefill_k_c_normed, + prefill_k[...,self.qk_nope_head_dim:], + kv_cache, + kv_cache_scale, + attn_metadata.slot_mapping.flatten()[num_decode_tokens:], + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + else: + ops.concat_and_cache_mla( + prefill_k_c_normed, + prefill_k[...,self.qk_nope_head_dim:], + kv_cache, + attn_metadata.slot_mapping.flatten()[num_decode_tokens:], + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + output = torch.empty(output.shape[0], + self.o_proj.output_size, + device=hidden_states_or_q_c.device, + dtype=hidden_states_or_q_c.dtype) + + if has_prefill: + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k, kv_cache,kv_cache_scale, + attn_metadata) + + if has_decode: + output[:num_decode_tokens] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, kv_cache_scale,attn_metadata, *write_kv_cache) + + return output diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py new file mode 100644 index 0000000..143bfe3 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) + +logger = init_logger(__name__) + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLAImpl"]: + return FlashMLAImpl + + +@dataclass +class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor] + num_splits: torch.Tensor + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): + pass + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, runner): + super().__init__(runner) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + tile_scheduler_metadata, num_splits = \ + get_mla_metadata( + seq_lens, + 
self.num_q_heads, + 1, # MQA for the decode path + ) + + return FlashMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + ) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashMLA V1 with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=attn_metadata.decode.block_table, + cache_seqlens=attn_metadata.decode.seq_lens, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=attn_metadata.decode. 
+ tile_scheduler_metadata, + num_splits=attn_metadata.decode.num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py new file mode 100644 index 0000000..6d33faa --- /dev/null +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) +import ixformer.inference.functions as ixf_ops +import vllm.envs as envs +from vllm import _custom_ops as ops + +logger = init_logger(__name__) + + +class TritonMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["TritonMLAImpl"]: + return TritonMLAImpl + + +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "TritonMLA V1 with FP8 KV cache not yet supported") + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache_scale: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_c_normed: torch.Tensor=None, + k_pe: torch.Tensor=None, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 Triton MLA not yet supported") + + B = q_nope.shape[0] + q = torch.cat([q_nope, q_pe], dim=-1) + + o = torch.empty(B, + self.num_heads, + self.kv_lora_rank, + dtype=q_nope.dtype, + device=q_nope.device) + + # num_kv_splits = 4 # TODO: heuristic + + # # TODO(lucas) Allocate ahead of time + # attn_logits = torch.empty( + # ( + # B, + # self.num_heads, + # num_kv_splits, + # # NOTE(lucas) idk why the +1 is here but sglang has it so we + # # just mirror that + # self.kv_lora_rank + 1, + # ), + # dtype=torch.float32, + # device=q.device, + # ) + + # # Add a head dim of 1 + # kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + # kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + # PAGE_SIZE = 
kv_c_and_k_pe_cache.size(1) + + # # Run MQA + # decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + # attn_metadata.decode.block_table, + # attn_metadata.decode.seq_lens, attn_logits, + # num_kv_splits, self.scale, PAGE_SIZE) + if envs.VLLM_USE_INT8_MLA: + q_int8, q_scale = ops.quant_kv(q) + ixf_ops.vllm_paged_attention_mla_int8( + o, + q_int8, + q_scale, + kv_c_and_k_pe_cache, + kv_c_and_k_pe_cache_scale, + self.scale, + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, + attn_metadata.decode.max_decode_seq_len, + attn_metadata.decode.use_cuda_graph + ) + + else: + # fused q concat & cache write + ixf_ops.vllm_paged_attention_mla_fused( + output=o, + q_nope=q_nope, + q_pe=q_pe.contiguous(), + kv_cache=kv_c_and_k_pe_cache, + scale=self.scale, + block_tables=attn_metadata.decode.block_table, + context_lens=attn_metadata.decode.seq_lens, + max_context_len=attn_metadata.decode.max_decode_seq_len, + k_c_normed=k_c_normed, + k_pe=k_pe, + use_cuda_graph=attn_metadata.decode.use_cuda_graph + ) + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py new file mode 100644 index 0000000..af729ee --- /dev/null +++ b/vllm/v1/attention/backends/pallas.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Optional + +import torch +# Required to register custom ops. +import torch_xla.experimental.custom_kernel # noqa: F401 + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads * 2, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + +@dataclass +class PallasMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Used in the PallasAttentionBackendImpl + slot_mapping: torch.Tensor + block_tables: torch.Tensor + context_lens: torch.Tensor + query_start_loc: torch.Tensor + num_seqs: int + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError("Paged attention Pallas kernel does " + "not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + tpu_version = torch_xla.tpu.version() + if tpu_version < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + # For determine_available_memory case. + if kv_cache.numel() == 0: + if output is None: + output = torch.ones_like(query) + return output + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + num_tokens, hidden_size = query.shape + query = query.view(num_tokens, self.num_heads, self.head_size) + + if kv_cache.numel() > 0: + slot_mapping = attn_metadata.slot_mapping + write_to_kv_cache(key, value, kv_cache, slot_mapping) + + output = torch.ops.xla.ragged_paged_attention( + query, + kv_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.query_start_loc, + attn_metadata.num_seqs, + # By default, the system utilizes optimized block size and + # vmem_limit_bytes parameters from the kernel repository. However, + # these can be manually adjusted for debugging if necessary. 
+ num_kv_pages_per_block=None, + num_queries_per_block=None, + vmem_limit_bytes=None, + use_kernel=True, + sm_scale=self.scale, + sliding_window=self.sliding_window, + soft_cap=self.logits_soft_cap, + ) + + return output.reshape(num_tokens, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + """ Write the key and values to the KV cache. + + Args: + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + + """ + _, _, num_combined_kv_heads, head_size = kv_cache.shape + num_kv_heads = num_combined_kv_heads // 2 + + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, + head_size) + + torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) + + kv_cache = kv_cache.flatten(0, 1) + kv_cache.index_copy_(0, slot_mapping, kv) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py new file mode 100644 index 0000000..5f96104 --- /dev/null +++ b/vllm/v1/attention/backends/triton_attn.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with PagedAttention and Triton prefix prefill.""" +from typing import Any, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata, FlashAttentionMetadataBuilder) + +logger = init_logger(__name__) + + +class TritonAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "TRITON_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["TritonAttentionImpl"]: + return TritonAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @staticmethod + def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + +class TritonAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + use_irope: bool = False, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "TritonAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = 
num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + self.use_irope = use_irope + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = TritonAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by TritonAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonAttentionImpl") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + assert attn_metadata.use_cascade is False + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + use_local_attn = \ + (self.use_irope and attn_metadata.local_attn_metadata is not None) + + if use_local_attn: + assert attn_metadata.local_attn_metadata is not None + local_metadata = attn_metadata.local_attn_metadata + cu_seqlens_q = local_metadata.local_query_start_loc + sequesd_k = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len + block_table = local_metadata.local_block_table + else: + cu_seqlens_q = attn_metadata.query_start_loc + sequesd_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + # Compute attention and update output up to `num_actual_tokens`. 
+ chunked_prefill_paged_decode(query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=sequesd_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale) + + return output diff --git a/vllm/v1/core/__init__.py b/vllm/v1/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py new file mode 100644 index 0000000..43f30f7 --- /dev/null +++ b/vllm/v1/core/block_pool.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections import defaultdict +from collections.abc import Iterable +from typing import Callable, Optional + +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, + KVCacheBlock, + generate_block_hash_extra_keys, + hash_block_tokens) +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class BlockPool: + """BlockPool that manages KVCacheBlocks. + It provides methods to allocate, free and cache the kv cache blocks. The + free_block_queue stores the free blocks in eviction order to enable + allocation, free, and cache eviction. The cached_block_hash_to_block + maps between block hash and cached block to support finding cached blocks + by their block hash. + + Args: + num_gpu_blocks: The number of blocks in the pool. + enable_caching: Whether to enable prefix caching. + """ + + def __init__(self, num_gpu_blocks: int, enable_caching: bool): + assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 + self.num_gpu_blocks = num_gpu_blocks + self.enable_caching = enable_caching + # All kv-cache blocks. + self.blocks: list[KVCacheBlock] = [ + KVCacheBlock(idx) for idx in range(num_gpu_blocks) + ] + # Free block queue that constructs and manipulates a doubly linked + # list of free blocks (including eviction candidates when caching is + # enabled). + self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) + + # {block_hash: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. + # The cached block may be used by running requests or in the + # free_block_queue that could potentially be evicted. + # NOTE: We currently don't de-duplicate the blocks in the cache, + # meaning that if a block becomes full and is cached, we don't check + # if there is already an identical block in the cache. This is because + # we want to make sure the allocated block IDs won't change so that + # block tables are append-only. + self.cached_block_hash_to_block: dict[BlockHashType, dict[ + int, KVCacheBlock]] = defaultdict(dict) + + # To represent a placeholder block with block_id=0. + # The ref_cnt of null_block is not maintained, needs special care to + # avoid freeing it. + self.null_block = self.free_block_queue.popleft() + + def get_cached_block(self, + block_hash: BlockHashType) -> Optional[KVCacheBlock]: + """Get a cached block by the block hash, or None if cache miss. + If there are duplicated blocks, we return the first block in the cache. + + Args: + block_hash: The hash value of the block. + + Returns: + The cached block if it exists, or None. 
+ """ + if block_hash in self.cached_block_hash_to_block: + first_block_id = list( + self.cached_block_hash_to_block[block_hash].keys())[0] + return self.cached_block_hash_to_block[block_hash][first_block_id] + return None + + def cache_full_blocks( + self, + request: Request, + blocks: list[KVCacheBlock], + block_hashes: list[BlockHashType], + num_cached_blocks: int, + num_full_blocks: int, + block_size: int, + hash_fn: Callable, + ) -> None: + """Cache a list of full blocks for prefix caching. + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `num_cached_blocks` to + `num_full_blocks`, updating the metadata for each block + and caching them in the `cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blocks: All blocks in the request. + block_hashes: Block hashes of the blocks in the request. Note that + this list may be shorter than the blocks list. In this case the + missed block hash will be computed in this function. + num_cached_blocks: The number of blocks that are already cached. + num_full_blocks: The number of blocks that are full and should + be cached after this function. + block_size: Number of tokens in each block. + hash_fn: The hash function to use for block hashes. + """ + if num_cached_blocks == num_full_blocks: + return + new_full_blocks = blocks[num_cached_blocks:num_full_blocks] + assert len(block_hashes) >= num_cached_blocks + new_block_hashes = block_hashes[num_cached_blocks:] + + # Update the new blocks with the block hashes through the chain. + if num_cached_blocks == 0: + prev_block_hash_value = None + else: + prev_block = blocks[num_cached_blocks - 1] + assert prev_block.block_hash is not None + prev_block_hash_value = prev_block.block_hash.hash_value + + for i, blk in enumerate(new_full_blocks): + assert blk.block_hash is None + + if i < len(new_block_hashes): + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = new_block_hashes[i] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + blk_idx = num_cached_blocks + i + start_token_idx = blk_idx * block_size + end_token_idx = (blk_idx + 1) * block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == block_size, ( + f"Expected {block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, + block_tokens, extra_keys) + block_hashes.append(block_hash) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash_value = block_hash.hash_value + + def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: + """Get new blocks from the free block pool. 
+ + Note that we do not check block cache in this function. + + Args: + num_blocks: The number of blocks to allocate. + + Returns: + A list of new block. + """ + if num_blocks > self.get_num_free_blocks(): + raise ValueError( + f"Cannot get {num_blocks} free blocks from the pool") + + ret: list[KVCacheBlock] = [] + idx = 0 + while idx < num_blocks: + # First allocate blocks. + curr_block = self.free_block_queue.popleft() + assert curr_block.ref_cnt == 0 + + # If the block is cached, evict it. + if self.enable_caching: + self._maybe_evict_cached_block(curr_block) + + curr_block.incr_ref() + ret.append(curr_block) + idx += 1 + + return ret + + def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: + """ + If a block is cached in `cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. + + Args: + block: The block to evict. + + Returns: + True if the block is evicted, False otherwise. + """ + block_hash = block.block_hash + if block_hash and block_hash in self.cached_block_hash_to_block: + block.reset_hash() + del self.cached_block_hash_to_block[block_hash][block.block_id] + + if len(self.cached_block_hash_to_block[block_hash]) == 0: + del self.cached_block_hash_to_block[block_hash] + + return True + return False + + def touch(self, blocks: list[KVCacheBlock]) -> None: + """Touch a block increases its reference count by 1, and may remove + the block from the free queue. This is used when a block is hit by + another request with the same prefix. + + Args: + blocks: A list of blocks to touch. + """ + for block in blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.remove(block) + block.incr_ref() + + def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + """Free a list of blocks. The blocks should be ordered by their + eviction priority, where the first block will be evicted first. + + Args: + ordered_blocks: A list of blocks to free ordered by their eviction + priority. + """ + for block in ordered_blocks: + block.decr_ref() + # null_block should not be added to the free list. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.append(block) + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks()) + if num_used_blocks != 1: # The null block is always marked as used + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks - 1) + return False + + # Remove all hashes so that no new blocks will hit. + self.cached_block_hash_to_block = defaultdict(dict) + + # Remove all hashes from all blocks. + for block in self.blocks: + block.reset_hash() + + logger.info("Successfully reset prefix cache") + return True + + def get_num_free_blocks(self) -> int: + """Get the number of free blocks in the pool. + + Returns: + The number of free blocks. + """ + return self.free_block_queue.num_free_blocks + + def get_usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). 
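+            Computed as ``1.0 - num_free_blocks / num_gpu_blocks``; the
+            reserved null block is never returned to the free queue, so it
+            always counts as used.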
+ """ + return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py new file mode 100644 index 0000000..dc76df2 --- /dev/null +++ b/vllm/v1/core/encoder_cache_manager.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING + +from vllm.logger import init_logger +from vllm.multimodal import MultiModalRegistry +from vllm.v1.request import Request + +if TYPE_CHECKING: + from vllm.config import ModelConfig, SchedulerConfig + +logger = init_logger(__name__) + + +class EncoderCacheManager: + + def __init__(self, cache_size: int): + self.cache_size = cache_size + self.num_free_slots = cache_size + # req_id -> cached input ids + self.cached: dict[str, set[int]] = {} + # list of [req_id, input_id] + self.freed: list[tuple[str, int]] = [] + + def has_cache(self, request: Request, input_id: int) -> bool: + req_id = request.request_id + return req_id in self.cached and input_id in self.cached[req_id] + + def can_allocate(self, request: Request, input_id: int) -> bool: + num_tokens = request.get_num_encoder_tokens(input_id) + return num_tokens <= self.num_free_slots + + def allocate(self, request: Request, input_id: int) -> None: + req_id = request.request_id + if req_id not in self.cached: + self.cached[req_id] = set() + self.cached[req_id].add(input_id) + self.num_free_slots -= request.get_num_encoder_tokens(input_id) + + def get_cached_input_ids(self, request: Request) -> set[int]: + return self.cached.get(request.request_id, set()) + + def free_encoder_input(self, request: Request, input_id: int) -> None: + """Free a single encoder input id for the request.""" + req_id = request.request_id + if req_id not in self.cached: + return + + self.cached[req_id].discard(input_id) + if len(self.cached[req_id]) == 0: + del self.cached[req_id] + self.num_free_slots += request.get_num_encoder_tokens(input_id) + self.freed.append((req_id, input_id)) + + def free(self, request: Request) -> None: + """Free all cached input ids for the request.""" + input_ids = self.get_cached_input_ids(request).copy() + for input_id in input_ids: + self.free_encoder_input(request, input_id) + + def get_freed_ids(self) -> list[tuple[str, int]]: + freed = self.freed + self.freed = [] + return freed + + +def compute_encoder_budget( + model_config: "ModelConfig", + scheduler_config: "SchedulerConfig", + mm_registry: MultiModalRegistry, +) -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations. + + Args: + model_config: Model configuration. + scheduler_config: Scheduler configuration. + mm_registry: Provides information about the token cost. + + Returns: + - Compute budget for encoder execution, in unit of number of tokens + in the input sequence. + - Space budget for encoder cache size, in unit of number of tokens + in the input sequence. + """ + + if not model_config.is_multimodal_model: + return 0, 0 + + # TODO: handle encoder-decoder models once we support them. + ( + encoder_compute_budget, + encoder_cache_size, + ) = _compute_encoder_budget_multimodal( + model_config, + scheduler_config, + mm_registry, + ) + + return encoder_compute_budget, encoder_cache_size + + +def _compute_encoder_budget_multimodal( + model_config: "ModelConfig", + scheduler_config: "SchedulerConfig", + mm_registry: MultiModalRegistry, +) -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations for a multimodal model. 
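+    Both budgets are floored at the token cost of the single most expensive
+    multimodal item, so at least one such item always fits in the compute
+    budget and in the encoder cache.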
+ + Args: + model_config: Model configuration. + scheduler_config: Scheduler configuration. + mm_registry: Provides information about the token cost. + + Returns: + - Compute budget for encoder execution, in unit of number of tokens + in the input sequence. + - Space budget for encoder cache size, in unit of number of tokens + in the input sequence. + """ + + max_tokens_by_modality_dict = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config) + + if not max_tokens_by_modality_dict: + logger.warning( + "All non-text modalities supported by the model have been " + "explicitly disabled via limit_mm_per_prompt. Encoder cache will " + "not be initialized.") + return 0, 0 + + _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), + key=lambda item: item[1]) + + encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens, + max_tokens_per_mm_item) + encoder_cache_size = max(scheduler_config.encoder_cache_size, + max_tokens_per_mm_item) + + return encoder_compute_budget, encoder_cache_size diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py new file mode 100644 index 0000000..c0f7715 --- /dev/null +++ b/vllm/v1/core/kv_cache_manager.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from collections.abc import Iterable +from typing import Optional + +from vllm.logger import init_logger +from vllm.utils import cdiv, sha256 +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + hash_request_tokens) +from vllm.v1.core.specialized_manager import get_specialized_manager +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class KVCacheManager: + + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + enable_caching: bool = True, + caching_hash_algo: str = "builtin", + num_preallocate_tokens: int = 64, + log_stats: bool = False, + ) -> None: + assert len(kv_cache_config.kv_cache_groups) == 1, ( + "KVCacheManager does not support hybrid models with more than 1 " + "kv cache group") + kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec + self.block_size = kv_cache_spec.block_size + self.num_gpu_blocks = kv_cache_config.num_blocks + self.max_model_len = max_model_len + self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) + + self.enable_caching = enable_caching + self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash + # FIXME: make prefix cache stats conditional on log_stats + self.log_stats = log_stats + # NOTE(woosuk): To avoid frequent block allocation, we preallocate some + # blocks for each request. For example, when a request reaches the end + # of its block table, we preallocate N blocks in advance. This way, we + # reduce the overhead of updating free_block_ids and ref_cnts for each + # request every step (at the cost of some memory waste). + # NOTE(woosuk): This is different from the "lookahead" slots since this + # does not guarantee that the request always has N empty blocks. After + # the request gets N empty blocks, it starts to use the blocks without + # further allocation. When it uses up all the N empty blocks, it gets + # N new empty blocks. 
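+        # Example: with num_preallocate_tokens=64 and a block size of 16
+        # (a common default), cdiv(64, 16) = 4 extra blocks are requested
+        # whenever the block table needs to grow.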
+ self.num_preallocate_tokens = num_preallocate_tokens + self.num_preallocate_blocks = cdiv(num_preallocate_tokens, + self.block_size) + + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) + + self.specialized_manager = get_specialized_manager( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + ) + + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: defaultdict[str, + list[KVCacheBlock]] = defaultdict(list) + + # Mapping from request ID to kv block hashes. + # This is to avoid recomputing the block hashes for each call of + # `get_computed_blocks` or `allocate_slots`. + self.req_to_block_hashes: defaultdict[ + str, list[BlockHashType]] = defaultdict(list) + + # {req_id: The number of cached blocks for this given request} + # This is used to track the number of cached blocks for each request. + # This is only used to track the RUNNING requests, we do not track the + # data for reempted ones. + self.num_cached_block: dict[str, int] = {} + self.prefix_cache_stats = PrefixCacheStats() + + @property + def usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). + """ + return self.block_pool.get_usage() + + def make_prefix_cache_stats(self) -> PrefixCacheStats: + """Get (and reset) the prefix cache stats. + + Returns: + The current prefix caching stats. + """ + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats + + def get_computed_blocks( + self, request: Request) -> tuple[list[KVCacheBlock], int]: + """Get the computed (cached) blocks for the request. + Note that the computed blocks must be full. + + Args: + request: The request to get the computed blocks. + + Returns: + A tuple containing: + - A list of blocks that are computed for the request. + - The number of computed tokens. + """ + if not self.enable_caching: + # Prefix caching is disabled. + return [], 0 + + # The block hashes for the request may already be computed + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = hash_request_tokens(self.caching_hash_fn, + self.block_size, request) + self.req_to_block_hashes[request.request_id] = block_hashes + + self.prefix_cache_stats.requests += 1 + if request.sampling_params.prompt_logprobs is None: + if len(block_hashes) * self.block_size == request.num_tokens: + # When prompt length is divisible by the block size and all + # blocks are cached, we need to recompute the last token. This + # have to be achieved by re-computing an entire block because + # allocate_slots() assumes num_computed_tokens is always a + # multiple of the block size. To achieve this, remove the last + # block hash from the block_hashes for find_longest_cache_hit + # This limitation can potentially be removed in the future to + # slightly improve the performance. + last_block_hash = block_hashes.pop() + else: + last_block_hash = None + + computed_blocks = ( + self.specialized_manager.find_longest_cache_hit(block_hashes)) + + if last_block_hash is not None: + # Add back the last block hash if it was removed. 
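+                # `block_hashes` aliases req_to_block_hashes[request_id], so
+                # re-appending the popped hash keeps the cached list intact
+                # for later allocate_slots() / cache_full_blocks() calls.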
+ block_hashes.append(last_block_hash) + + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) + + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks) * self.block_size + return computed_blocks, num_computed_tokens + else: + # Skip cache hits for prompt logprobs + return [], 0 + + def allocate_slots( + self, + request: Request, + num_tokens: int, + new_computed_blocks: Optional[list[KVCacheBlock]] = None + ) -> Optional[list[KVCacheBlock]]: + """Add slots for a request with new tokens to append. + + Args: + request: The request to allocate slots. + num_tokens: The number of tokens to allocate. Note that this does + not include the tokens that have already been computed. + new_computed_blocks: A list of new computed blocks just hitting the + prefix caching. + + Blocks layout: + ----------------------------------------------------------------------- + | < computed > | < new computed > | < new > | < pre-allocated > | + ----------------------------------------------------------------------- + | < required > | + -------------------------------------------------- + | < full > | + ------------------------------------------------ + | | + -------------- + The following *_blocks are illustrated in this layout. + + Returns: + A list of new allocated blocks. + """ + if num_tokens == 0: + raise ValueError("num_tokens must be greater than 0") + + new_computed_blocks = new_computed_blocks or [] + + req_blocks = self.req_to_blocks[request.request_id] + + # Free the blocks that are skipped during the attention computation + # (e.g., tokens outside the sliding window). + # We can do this even if we cannot schedule this request due to + # insufficient free blocks. + # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. + removed_blocks = self.specialized_manager.remove_skipped_blocks( + req_blocks, request.num_computed_tokens) + self.block_pool.free_blocks(removed_blocks) + + # The number of computed tokens is the number of computed tokens plus + # the new prefix caching hits + num_computed_tokens = (request.num_computed_tokens + + len(new_computed_blocks) * self.block_size) + num_required_blocks = cdiv(num_computed_tokens + num_tokens, + self.block_size) + num_new_blocks = (num_required_blocks - len(req_blocks) - + len(new_computed_blocks)) + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + if blk.ref_cnt == 0) + if (num_new_blocks > self.block_pool.get_num_free_blocks() - + num_evictable_computed_blocks): + # Cannot allocate new blocks + return None + + # Touch the computed blocks to make sure they won't be evicted. + if self.enable_caching: + self.block_pool.touch(new_computed_blocks) + else: + assert not new_computed_blocks, ( + "Computed blocks should be empty when " + "prefix caching is disabled") + + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + req_blocks.extend(new_computed_blocks) + + # Start to handle new blocks + + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + # Get new blocks from the free block pool considering + # preallocated blocks. 
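+            # Example: if 2 blocks are strictly required and
+            # num_preallocate_blocks is 4, up to 6 blocks are requested,
+            # capped by the free pool size and max_num_blocks_per_req.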
+ num_new_blocks = min( + num_new_blocks + self.num_preallocate_blocks, + self.block_pool.get_num_free_blocks(), + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + self.max_num_blocks_per_req - len(req_blocks), + ) + assert num_new_blocks > 0 + + # Concatenate the computed block IDs and the new block IDs. + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + + if not self.enable_caching: + return new_blocks + + # Use `new_computed_blocks` for a new request, and `num_cached_block` + # for a running request. + num_cached_blocks = self.num_cached_block.get(request.request_id, + len(new_computed_blocks)) + # Speculated tokens might be rejected in the future, so we does + # not cache any speculated tokens. We only cache blocks with + # generated (accepted) tokens. + num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( + request.spec_token_ids)) // self.block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks, + block_hashes=self.req_to_block_hashes[request.request_id], + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks_after_append, + block_size=self.block_size, + hash_fn=self.caching_hash_fn, + ) + + self.num_cached_block[ + request.request_id] = num_full_blocks_after_append + return new_blocks + + def free(self, request: Request) -> None: + """Free the blocks allocated for the request. + When caching is enabled, we free the blocks in reverse order so that + the tail blocks are evicted first. + + Args: + request: The request to free the blocks. + """ + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request.request_id, []) + ordered_blocks: Iterable[KVCacheBlock] = blocks + if self.enable_caching: + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(blocks) + + self.block_pool.free_blocks(ordered_blocks) + self.num_cached_block.pop(request.request_id, None) + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + if self.block_pool.reset_prefix_cache(): + self.prefix_cache_stats.reset = True + return True + return False + + def get_num_common_prefix_blocks( + self, + request: Request, + num_running_requests: int, + ) -> int: + """Calculate the number of common prefix blocks shared by all requests + in the RUNNING state. + + The function determines this by selecting any request and iterating + through its blocks. A block is considered a common prefix block if its + `ref_cnt` equals the total number of requests in the RUNNING state. + + NOTE(woosuk): The number of requests in the RUNNING state is **greater + than or equal to** the number of requests scheduled in the current step. + This is because the RUNNING state only indicates that: + 1. The request has not yet finished, and + 2. The request holds its blocks unfreed. + + While all scheduled requests must be in the RUNNING state, the inverse + is not necessarily true. There may be RUNNING requests that are not + scheduled in the current step. 
+ + This can result in an edge case where the number of common prefix blocks + is 0, even though all scheduled requests share a common prefix. This + occurs because there may be unscheduled RUNNING requests that do not + share the common prefix. Currently, this case cannot be easily detected, + so the function returns 0 in such cases. + + Args: + request: Any request in the RUNNING state, used to identify the + common prefix blocks. + num_running_requests: The total number of requests in the RUNNING + state. This can be different from the number of scheduled + requests in the current step. + + Returns: + int: The number of common prefix blocks. + """ + assert request.status == RequestStatus.RUNNING + blocks = self.req_to_blocks[request.request_id] + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks + + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. + + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. + """ + self.req_to_block_hashes.pop(request.request_id, None) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py new file mode 100644 index 0000000..d3575fb --- /dev/null +++ b/vllm/v1/core/kv_cache_utils.py @@ -0,0 +1,690 @@ +# SPDX-License-Identifier: Apache-2.0 +"""KV-Cache Utilities.""" +import os +from collections import deque +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Callable, NamedTuple, Optional + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import sha256 +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheSpec, + KVCacheTensor, SlidingWindowSpec) +from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class BlockHashType(NamedTuple): + """Hash value of a block (int), the token IDs in the block, and extra keys. + We keep a tuple of token IDs and extra keys to reduce the likelihood of + hash collisions when the hash value is the same. By using SHA256 however, + hash collisions are practically impossible. + """ + # Hash value of the block in an integer. + hash_value: int + # Token IDs in the block. + token_ids: tuple[int, ...] + # Extra keys for the block. + extra_keys: Optional[Any] = None + + +# The hash seed for the first block of the prefix block sequence. +# +# Even if the hash function is the builtin hash(), we use sha256 to generate +# the initial hash to simplify the code. This is not performance critical +# as it is done one per process. +# +# We use a random value to avoid hash collisions or PYTHONHASHSEED environment +# variable if set such that processes can share the seed if needed. +# This aligns with the behavior of Python's hash() function, which also uses +# a random seed if PYTHONHASHSEED is not set. +NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( + 'PYTHONHASHSEED') is not None else sha256(os.getenv('PYTHONHASHSEED')) + + +class PrefixCachingMetrics: + """Metrics for prefix caching with a hit rate of the most recent N requests. + + Args: + interval: The number of the most recent requests to aggregate. + Defaults to 1000. + """ + + def __init__(self, interval: int = 1000): + self.interval = interval + # The current aggregated values. 
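+        # aggregated_requests: requests observed in the window;
+        # aggregated_query_total: prefix-cache block lookups issued for them;
+        # aggregated_query_hit: lookups that found a cached block.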
+ self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + # A deque of (requests, queries, hits) for the most recent requests. + self.query_queue: deque[tuple[int, int, int]] = deque() + + def observe(self, stats: PrefixCacheStats): + """Observe the prefix caching for a set of requests. + + This function is called with information gathered when new requests + are being scheduled and are looking for computed blocks. + + When there are more than `interval` requests, the oldest set of + requestsare removed from the metrics. + + Args: + stats: The prefix cache stats. + """ + # reset_prefix_cache was invoked before the current update. + # Reset the metrics before aggregating the current stats. + if stats.reset: + self.reset() + + # Update the metrics. + self.query_queue.append((stats.requests, stats.queries, stats.hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits + + # Remove the oldest stats if the number of requests exceeds. + if self.aggregated_requests > self.interval: + old_requests, old_queries, old_hits = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits + + def reset(self): + """Reset the metrics.""" + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + self.query_queue.clear() + + @property + def hit_rate(self) -> float: + """Calculate the hit rate for the past N requests.""" + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total + + +@dataclass +class KVCacheBlock: + """KV-cache block metadata.""" + # Block ID, ranging from 0 to num_gpu_blocks - 1. + block_id: int + # Reference count. + ref_cnt: int = 0 + # The hash of the block composed of (block hash, tuple of token IDs). + # It is only available when the block is full. + _block_hash: Optional[BlockHashType] = None + + # Used to construct a doubly linked list for free blocks. + # These two attributes should only be manipulated by FreeKVCacheBlockQueue. + prev_free_block: Optional["KVCacheBlock"] = None + next_free_block: Optional["KVCacheBlock"] = None + + def incr_ref(self): + self.ref_cnt += 1 + + def decr_ref(self): + self.ref_cnt -= 1 + + @property + def block_hash(self) -> Optional[BlockHashType]: + return self._block_hash + + @block_hash.setter + def block_hash(self, block_hash: BlockHashType): + assert self.block_hash is None, ( + "The block already has a hash. This should not happen.") + self._block_hash = block_hash + + def reset_hash(self): + """Reset the block hash when the block is evicted.""" + self._block_hash = None + + def __repr__(self) -> str: + # Use block_id instead of KVCacheBlock object to avoid calling __repr__ + # on KVCacheBlock object recursively. + prev_block_id = self.prev_free_block.block_id \ + if self.prev_free_block else None + next_block_id = self.next_free_block.block_id \ + if self.next_free_block else None + return (f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}, " + f"_block_hash={self._block_hash}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})") + + +class FreeKVCacheBlockQueue: + """This class organizes a list of KVCacheBlock objects to a doubly linked + list of free blocks. 
We implement this class instead of using Python + builtin deque to support removing a block in the middle of the queue + in O(1) time. To close the performance gap to the builtin deque which is + implemented in C++, this class does not allocate any Python objects when + manipulating the linked list. Instead, this class manipulates the + prev_free_block and next_free_block attributes of the given blocks. + + The queue is ordered by block ID in the beginning. When a block is allocated + and then freed, it will be appended back with the eviction order: + 1. The least recent used block is at the front (LRU). + 2. If two blocks have the same last accessed time (allocated by the + same sequence), the one with more hash tokens (the tail of a block + chain) is at the front. + Note that we maintain this order by reversing the block order when free + blocks of a request. This operation is outside of this class. + + Args: + blocks: A list of KVCacheBlock objects. + """ + + def __init__(self, blocks: list[KVCacheBlock]) -> None: + self.num_free_blocks = len(blocks) + + # Initialize the doubly linked list of free blocks. + self.free_list_head: Optional[KVCacheBlock] = blocks[0] + self.free_list_tail: Optional[KVCacheBlock] = blocks[-1] + for i in range(self.num_free_blocks): + if i > 0: + blocks[i].prev_free_block = blocks[i - 1] + if i < self.num_free_blocks - 1: + blocks[i].next_free_block = blocks[i + 1] + + def popleft(self) -> KVCacheBlock: + """Pop the first free block and reduce num_free_blocks by 1. + + Returns: + The first free block. + """ + if not self.free_list_head: + raise ValueError("No free blocks available") + + block = self.free_list_head + self.remove(block) + return block + + def remove(self, block: KVCacheBlock) -> None: + """Remove a block in the free list and reduce num_free_blocks by 1. + + Args: + block: The block to remove. + """ + if block.prev_free_block is not None: + # Link the previous block to the next block. + block.prev_free_block.next_free_block = block.next_free_block + if block.next_free_block is not None: + # Link the next block to the previous block. + block.next_free_block.prev_free_block = block.prev_free_block + + if block == self.free_list_head: + # Update the head if the block is the head. + self.free_list_head = block.next_free_block + if block == self.free_list_tail: + # Update the tail if the block is the tail. + self.free_list_tail = block.prev_free_block + + # Remove the block from the linked list. + block.prev_free_block = block.next_free_block = None + self.num_free_blocks -= 1 + + def append(self, block: KVCacheBlock) -> None: + """Put a block back into the free list and increase + num_free_blocks by 1. + + Args: + block: The block to append. + """ + if self.free_list_tail is not None: + # Link the last block to the new block. + self.free_list_tail.next_free_block = block + block.prev_free_block = self.free_list_tail + self.free_list_tail = block + else: + # The free list is empty. + assert self.free_list_head is None + self.free_list_head = self.free_list_tail = block + + block.next_free_block = None + self.num_free_blocks += 1 + + def get_all_free_blocks(self) -> list[KVCacheBlock]: + """Get all free blocks in the free list. Mainly used for testing. + + Returns: + A list of free blocks. 
+ """ + ret = [] + curr_block = self.free_list_head + while curr_block is not None: + ret.append(curr_block) + curr_block = curr_block.next_free_block + return ret + + +def need_extra_keys(request: Request) -> bool: + """Check whether the blocks allocated to this request need extra hash keys. + + Args: + request (Request): The request. + + Returns: + bool: Whether blocks allocated to this request need extra hash keys. + """ + + # Multimodal requests need to include the MM hash. + # LoRA requests need to include the LoRA ID. + return bool(request.mm_positions) or (request.lora_request is not None) + + +def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, + end_token_idx: int, + start_mm_idx: int) -> tuple[list[Any], int]: + """Generate extra keys related to MultiModal request for block hash + computation. For multi-modal inputs, the extra keys are + (mm_hash, start_offset) that indicate a mm input contained in the + block and its starting offset in the block tokens. + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + extra_keys: list[Any] = [] + + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if not mm_positions: + return extra_keys, start_mm_idx + + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match. This " + "is likely because you do not enable MM preprocessor hashing. " + "Please set disable_mm_preprocessor_cache=False.") + + # Note that we assume mm_positions is sorted by offset. + # We do not need to check all mm inputs if the start token index is out of + # range. This usually happens in the late prefill phase and decoding phase. + if mm_positions[-1]["offset"] + mm_positions[-1][ + "length"] < start_token_idx: + return extra_keys, start_mm_idx + + # Support start_mm_idx == -1 to indicate the last mm input. + if start_mm_idx < 0: + assert -start_mm_idx <= len(mm_positions) + start_mm_idx = len(mm_positions) + start_mm_idx + + curr_mm_idx = start_mm_idx + while mm_positions and curr_mm_idx < len(mm_positions): + assert mm_hashes[curr_mm_idx] is not None + offset = mm_positions[curr_mm_idx]["offset"] + length = mm_positions[curr_mm_idx]["length"] + if end_token_idx > offset: + if start_token_idx > offset + length: + # This block has passed the current mm input. + curr_mm_idx += 1 + continue + + # The block contains the current mm input. + extra_keys.append(mm_hashes[curr_mm_idx]) + + if end_token_idx >= offset + length: + # If this block contains the end of the current mm input, + # move to the next mm input as this block may also contain + # the next mm input. + curr_mm_idx += 1 + else: + # Otherwise this block is done with mm inputs. + break + else: + # This block has not reached the current mm input. + break + return extra_keys, curr_mm_idx + + +def _gen_lora_extra_hash_keys(request: Request) -> list[int]: + """Generate extra keys related to LoRA for block hash computation. + + Args: + request: The request object. + + Returns: + Return LoRA id of the request if it is a LoRA request. Return empty + list otherwise. 
+ """ + if not request.lora_request: + return [] + return [request.lora_request.lora_int_id] + + +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + mm_extra_keys: list[Any] + mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( + request, start_token_idx, end_token_idx, start_mm_idx) + lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) + + extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + + if not extra_keys: + return None, new_start_mm_idx + + return tuple(extra_keys), new_start_mm_idx + + +def hash_block_tokens( + hash_function: Callable, + parent_block_hash: Optional[int], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType: + """Computes a hash value corresponding to the contents of a block and + the contents of the preceding block(s). The hash value is used for + prefix caching. We use LRU cache for this function to avoid recomputing + hash values for the same block contents. + + Args: + parent_block_hash: The hash of the parent block. None + if this is the first block. + curr_block_token_ids: A list of token ids in the current + block. The current block is assumed to be full. + extra_keys: Extra keys for the block. + + Returns: + The hash value of the block and the token ids in the block. + The entire tuple is used as the hash key of the block. + """ + if not parent_block_hash: + parent_block_hash = NONE_HASH + + curr_block_token_ids_tuple = tuple(curr_block_token_ids) + return BlockHashType( + hash_function( + (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), + curr_block_token_ids_tuple, extra_keys) + + +def hash_request_tokens(hash_function: Any, block_size: int, + request: Request) -> list[BlockHashType]: + """Computes hash values of a chain of blocks given a sequence of + token IDs. The hash value is used for prefix caching. + + Args: + block_size: The size of each block. + request: The request object. + + Returns: + The list of computed hash values. + """ + token_ids = request.all_token_ids + + req_need_extra_keys = need_extra_keys(request) + req_extra_keys = None + curr_mm_idx = 0 + + ret = [] + parent_block_hash_value = None + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + # Do not hash the block if it is not full. + if len(block_token_ids) < block_size: + break + + if req_need_extra_keys: + # MM and LoRA requests need extra keys for block-hash computation. 
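+            # These keys keep token-identical blocks from being shared across
+            # requests that use different images or different LoRA adapters.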
+ req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start, end, curr_mm_idx) + + block_hash = hash_block_tokens(hash_function, parent_block_hash_value, + block_token_ids, req_extra_keys) + ret.append(block_hash) + parent_block_hash_value = block_hash.hash_value + return ret + + +def check_enough_kv_cache_memory(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int): + """ + Checks whether `available_memory` is enough for the KV cache to hold at + least one request with the model's max_model_len. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Raises: + ValueError: If there is not enough memory available for the KV cache. + """ + + if available_memory <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_model_len = vllm_config.model_config.max_model_len + needed_memory = 0 + for layer_spec in kv_cache_spec.values(): + needed_memory += layer_spec.max_memory_usage_bytes(vllm_config) + + if needed_memory > available_memory: + raise ValueError( + f"To serve at least one request with the models's max seq len " + f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV " + f"cache is needed, which is larger than the available KV cache " + f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try " + f"increasing `gpu_memory_utilization` or decreasing " + f"`max_model_len` when initializing the engine.") + + +def create_kv_cache_group_specs( + kv_cache_spec: dict[str, KVCacheSpec], + grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: + """ + Create KVCacheGroupSpec object for each kv cache group layer. + The layers in the same group should share the same + KVCacheSpec. + + Args: + kv_cache_spec: + A mapping from each layer name to its corresponding KVCacheSpec. + grouped_layer_names: + A list of kv cache groups, where each element is a list of layer + names that belong to the same group and should share the same + KVCacheSpec. + Returns: + A list of KVCacheGroupSpec objects, one for each group. + """ + kv_cache_groups = [] + for layer_names_one_group in grouped_layer_names: + layer_spec = kv_cache_spec[layer_names_one_group[0]] + assert all( + kv_cache_spec[layer_name] == layer_spec + for layer_name in layer_names_one_group[1:]), ( + "All layers in the same KV cache group must share the same " + "KVCacheSpec.") + kv_cache_groups.append( + KVCacheGroupSpec(layer_names_one_group, layer_spec)) + return kv_cache_groups + + +def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same type of KV cache. + + Args: + kv_cache_spec: The kv cache spec of each attention layer in the model + + Returns: + True if all layers have the same type, False otherwise. + """ + + layer_keys = set(layer.type_id for layer in kv_cache_spec.values()) + return len(layer_keys) == 1 + + +def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with one type of KV cache. + Divide the available memory equally among all layers. 
+ + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfig + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + scale_page_sizes = {layer.scale_page_size_bytes for layer in kv_cache_spec.values()} + assert len(page_sizes) == 1 + page_size = page_sizes.pop() + scale_page_size = scale_page_sizes.pop() + + num_blocks = int(available_memory // (page_size + scale_page_size) // len(kv_cache_spec)) + num_blocks = max(num_blocks, 0) + + if vllm_config.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = \ + vllm_config.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + num_blocks = num_gpu_blocks_override + + num_tokens = num_blocks * vllm_config.cache_config.block_size + num_tokens_str = f"{num_tokens:,}" + logger.info("GPU KV cache size: %s tokens", num_tokens_str) + max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_concurrency = num_tokens / vllm_config.model_config.max_model_len + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, max_concurrency) + + per_layer_size = (page_size + scale_page_size) * num_blocks + # All layers have the same KV cache spec, so we create one kv cache group + # for all layers. + grouped_layer_names = [list(kv_cache_spec.keys())] + + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + tensors={ + layer_name: KVCacheTensor(size=per_layer_size) + for layer_name in kv_cache_spec + }, + kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, + grouped_layer_names), + ) + return kv_cache_config + + +def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): + """ + Only models with one type of KV cache are supported yet. This function tries + to convert the KV cache specs to one type if the model is a hybrid model + with multiple type of KV cache. It will convert all SlidingWindowSpec to + FullAttentionSpec if both types are present. + + Args: + kv_cache_spec: The kv cache spec of each attention layer in the model + """ + + has_full_attention = any( + isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) + has_sliding_window = any( + isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) + if has_full_attention and has_sliding_window: + for layer_name, spec in kv_cache_spec.items(): + if isinstance(spec, SlidingWindowSpec): + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + use_mla=spec.use_mla, + ) + + +def get_kv_cache_config(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model + TODO: support hybrid models with more than one type of KV cache. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfigs + """ + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + unify_hybrid_kv_cache_specs(kv_cache_spec) + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for + # most models. 
Allocate the same amount of memory for + # each layer. + return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory) + + raise NotImplementedError + + +def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): + """ + Make the KV cache configurations for each worker consistent, so that all + workers can be controlled by the same KVCacheManager. + This function verifies that the layer group of each worker are the same, + and changes the num_blocks of each worker to the smallest among all workers. + + Args: + kv_cache_configs: The KV cache configurations for each worker. Will be + in-place modified to make them consistent. + """ + + # Sort the kv cache groups by the type_id of their KV cache spec. + # This can avoid the inconsistency caused by the order of groups. + for kv_cache_config in kv_cache_configs: + kv_cache_config.kv_cache_groups.sort( + key=lambda x: x.kv_cache_spec.type_id) + + # Verify that the groups of each rank are the same. + for kv_cache_config in kv_cache_configs[1:]: + for group_rank_0, group_rank_i in zip( + kv_cache_configs[0].kv_cache_groups, + kv_cache_config.kv_cache_groups): + assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec + + # Change the num_blocks of each rank to the smallest among all ranks. We + # do not need to shrink the tensor size because it is valid to only use the + # first `num_blocks` blocks of the tensor. + min_num_blocks = min(kv_cache_config.num_blocks + for kv_cache_config in kv_cache_configs) + for kv_cache_config in kv_cache_configs: + kv_cache_config.num_blocks = min_num_blocks + + return kv_cache_configs diff --git a/vllm/v1/core/sched/__init__.py b/vllm/v1/core/sched/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py new file mode 100644 index 0000000..bfed44f --- /dev/null +++ b/vllm/v1/core/sched/interface.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.engine import EngineCoreOutputs + from vllm.v1.metrics.stats import SchedulerStats + from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + + +class SchedulerInterface(ABC): + + @abstractmethod + def schedule(self) -> "SchedulerOutput": + """Schedule the requests to process in this scheduling step. + + The scheduling decision is made at the iteration level. Each scheduling + step corresponds to a single forward pass of the model. Therefore, this + method is called repeatedly by a busy loop in the engine. + + Essentially, the scheduler produces a dictionary of {req_id: num_tokens} + that specifies how many tokens to process for each request in this + scheduling step. For example, num_tokens can be as large as the number + of prompt tokens for new requests, or it can be 1 for the requests that + are auto-regressively generating new tokens one by one. Otherwise, it + can be somewhere in between in case of chunked prefills, prefix caching, + speculative decoding, etc. + + Additionally, the scheduler also returns useful data about each request + or the batch as a whole. The model runner will use this information in + preparing inputs to the model. + + Returns: + A SchedulerOutput object containing information about the scheduled + requests. 
+ """ + raise NotImplementedError + + @abstractmethod + def update_from_output( + self, + scheduler_output: "SchedulerOutput", + model_runner_output: "ModelRunnerOutput", + ) -> "EngineCoreOutputs": + """Update the scheduler state based on the model runner output. + + This method is called after the model runner has processed the scheduled + requests. The model runner output includes generated token ids, draft + token ids for next step, etc. The scheduler uses this information to + update its states, checks the finished requests, and returns the output + for each request. + + Returns: + A EngineCoreOutputs object containing the outputs for each request. + """ + raise NotImplementedError + + @abstractmethod + def add_request(self, request: "Request") -> None: + """Add a new request to the scheduler's internal queue. + + Args: + request: The new request being added. + """ + raise NotImplementedError + + @abstractmethod + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: "RequestStatus", + ) -> None: + """Finish the requests in the scheduler's internal queue. If the request + is not in the queue, this method will do nothing. + + This method is called in two cases: + 1. When the request is aborted by the client. + 2. When the frontend process detects a stop string of the request after + de-tokenizing its generated tokens. + + Args: + request_ids: A single or a list of request IDs. + finished_status: The finished status of the given requests. + """ + raise NotImplementedError + + @abstractmethod + def get_num_unfinished_requests(self) -> int: + """Number of unfinished requests in the scheduler's internal queue.""" + raise NotImplementedError + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests in the scheduler's + internal queue.""" + return self.get_num_unfinished_requests() > 0 + + @abstractmethod + def has_finished_requests(self) -> bool: + """Returns True if there are finished requests that need to be cleared. + NOTE: This is different from `not self.has_unfinished_requests()`. + + The scheduler maintains an internal list of the requests finished in the + previous step. This list is returned from the next call to schedule(), + to be sent to the model runner in the next step to clear cached states + for these finished requests. + + This method checks if this internal list of finished requests is + non-empty. This information is useful for DP attention. + """ + raise NotImplementedError + + def has_requests(self) -> bool: + """Returns True if there are unfinished requests, or finished requests + not yet returned in SchedulerOutputs.""" + return self.has_unfinished_requests() or self.has_finished_requests() + + @abstractmethod + def get_num_unscheduled_requests(self) -> int: + """Number of requests that are not being processed by the executor.""" + raise NotImplementedError + + @abstractmethod + def reset_prefix_cache(self) -> bool: + """Reset the prefix cache for KV cache. + + This is particularly required when the model weights are live-updated. + """ + raise NotImplementedError + + @abstractmethod + def make_stats(self) -> Optional["SchedulerStats"]: + """Make a SchedulerStats object for logging. + + The SchedulerStats object is created for every scheduling step. 
+ """ + raise NotImplementedError diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py new file mode 100644 index 0000000..dc0d2d5 --- /dev/null +++ b/vllm/v1/core/sched/output.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + + from vllm.lora.request import LoRARequest + from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange + from vllm.sampling_params import SamplingParams + from vllm.v1.request import Request + + +@dataclass +class NewRequestData: + + req_id: str + prompt_token_ids: list[int] + prompt: Optional[str] + mm_inputs: list[MultiModalKwargs] + mm_hashes: list[str] + mm_positions: list[PlaceholderRange] + sampling_params: SamplingParams + block_ids: list[int] + num_computed_tokens: int + lora_request: Optional[LoRARequest] + + @classmethod + def from_request( + cls, + request: Request, + block_ids: list[int], + ) -> NewRequestData: + return cls( + req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=request.num_computed_tokens, + lora_request=request.lora_request, + ) + + +@dataclass +class CachedRequestData: + + req_id: str + # If resumed_from_preemption is False, new_block_ids will be appended to + # the request's block IDs. If True, new_block_ids will be used as the + # request's block IDs instead of appending to the existing block IDs. + resumed_from_preemption: bool + new_token_ids: list[int] + new_block_ids: list[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + resumed_from_preemption: bool, + new_token_ids: list[int], + new_block_ids: list[int], + ) -> CachedRequestData: + return cls( + req_id=request.request_id, + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=request.num_computed_tokens, + ) + + +@dataclass +class SchedulerOutput: + + # list of the requests that are scheduled for the first time. + # We cache the request's data in each worker process, so that we don't + # need to re-send it every scheduling step. + scheduled_new_reqs: list[NewRequestData] + # list of the requests that have been scheduled before. + # Since the request's data is already cached in the worker processes, + # we only send the diff to minimize the communication cost. + scheduled_cached_reqs: list[CachedRequestData] + + # req_id -> num_scheduled_tokens + # Number of tokens scheduled for each request. + num_scheduled_tokens: dict[str, int] + # Total number of tokens scheduled for all requests. + # Equal to sum(num_scheduled_tokens.values()) + total_num_scheduled_tokens: int + # req_id -> spec_token_ids + # If a request does not have any spec decode tokens, it will not be + # included in the dictionary. + scheduled_spec_decode_tokens: dict[str, list[int]] + # req_id -> encoder input indices that need processing. + # E.g., if a request has [0, 1], it could mean the vision encoder needs + # to process that the request's 0-th and 1-th images in the current step. + scheduled_encoder_inputs: dict[str, list[int]] + # Number of common prefix blocks for all requests. + # This can be used for cascade attention. 
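+    # (Typically filled from KVCacheManager.get_num_common_prefix_blocks().)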
+ num_common_prefix_blocks: int + + # Request IDs that are finished in between the previous and the current + # steps. This is used to notify the workers about the finished requests + # so that they can free the cached states for those requests. + finished_req_ids: set[str] + # list of (req_id, encoder_input_index) tuples. + # Used to free the encoder cache. + free_encoder_input_ids: list[tuple[str, int]] + + # Dict of request ids to their index within the batch + # for filling the next token bitmask + structured_output_request_ids: dict[str, int] + # the bitmask for the whole batch + grammar_bitmask: Optional[npt.NDArray[np.int32]] diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py new file mode 100644 index 0000000..81f8ad2 --- /dev/null +++ b/vllm/v1/core/sched/scheduler.py @@ -0,0 +1,759 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import time +from collections import deque +from collections.abc import Iterable +from typing import Optional, Union + +from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, + compute_encoder_budget) +from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, + SchedulerOutput) +from vllm.v1.core.sched.utils import check_stop +from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, + EngineCoreOutputs) +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.v1.structured_output import StructuredOutputManager + +logger = init_logger(__name__) + + +class Scheduler(SchedulerInterface): + + def __init__( + self, + scheduler_config: SchedulerConfig, + model_config: ModelConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + self.lora_config = lora_config + self.kv_cache_config = kv_cache_config + self.log_stats = log_stats + self.structured_output_manager = structured_output_manager + + # include_finished_set controls whether a separate set of finished + # request ids should be included in the EngineCoreOutputs returned + # by update_from_outputs(). This is currently used in the multi-engine + # case to track request lifetimes efficiently. + self.include_finished_set = include_finished_set + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + + # Create the KV cache manager. 
+ self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, + enable_caching=cache_config.enable_prefix_caching, + caching_hash_algo=self.cache_config.prefix_caching_hash_algo, + log_stats=self.log_stats) + self.block_size = self.cache_config.block_size + + # req_id -> Request + self.requests: dict[str, Request] = {} + # Priority queues for requests. + self.waiting: deque[Request] = deque() + self.running: list[Request] = [] + # The requests that have been scheduled and are being executed + # by the executor. + self.scheduled_req_ids: set[str] = set() + + # The request IDs that are finished in between the previous and the + # current steps. This is used to notify the workers about the finished + # requests so that they can free the cached states for those requests. + # This is flushed at the end of each scheduling step. + self.finished_req_ids: set[str] = set() + + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating + # them at each scheduling step. + # Request id -> CachedRequestData + self._cached_reqs_data: dict[str, CachedRequestData] = {} + + # Encoder-related. + # Calculate encoder cache size if applicable + # NOTE: For now we use the same budget for both compute and space. + # This can be changed when we make encoder cache for embedding caching + # across requests. + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + mm_registry=mm_registry, + ) + + # NOTE(woosuk): Here, "encoder" includes the vision encoder (and + # projector if needed). Currently, we assume that the encoder also + # has the Transformer architecture (e.g., ViT). + self.max_num_encoder_input_tokens = encoder_compute_budget + # NOTE: For the models without encoder (e.g., text-only models), + # the encoder cache will not be initialized because cache size is 0 + # for these models. + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size) + + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to the running request index. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + + req_to_new_block_ids: dict[str, list[int]] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. 
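
The scheduling note above is the heart of the V1 design: there is no separate prefill or decode phase; every request simply advances `num_computed_tokens` toward its total token count under one shared per-step token budget. A rough, self-contained sketch of that budgeting loop with made-up numbers:

```python
# Toy requests as (request id, total tokens to reach, tokens already computed).
# "a" is mid-prefill, "b" and "c" are effectively decoding; values are made up.
requests = [("b", 900, 896), ("c", 50, 49), ("a", 1000, 0)]

token_budget = 512  # analogous to max_num_batched_tokens for one step
num_scheduled_tokens: dict[str, int] = {}

for req_id, num_tokens, num_computed in requests:
    if token_budget == 0:
        break
    # Each request tries to catch num_computed_tokens up to its total,
    # capped by whatever budget is left in this step.
    num_new = min(num_tokens - num_computed, token_budget)
    if num_new <= 0:
        continue
    num_scheduled_tokens[req_id] = num_new
    token_budget -= num_new

# The same loop covers chunked prefill and decode without distinguishing phases.
print(num_scheduled_tokens)  # {'b': 4, 'c': 1, 'a': 507}
```
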
+ scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + if request.request_id in self.scheduled_req_ids: + # This request has already been scheduled. + req_index += 1 + continue + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled because the encoder budget + # or the encoder cache is exhausted. + # NOTE(woosuk): By using `continue` instead of `break` here, + # we intentionally relax the strict FCFS scheduling policy + # to allow lower-priority requests to be scheduled when a + # higher-priority request is blocked by encoder constraints. + req_index += 1 + continue + else: + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp) + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + self.scheduled_req_ids.add(request.request_id) + if request.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. + structured_output_request_ids[request.request_id] = req_index + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. 
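
The preemption loop above keeps evicting the most recently added running request (and returning its KV blocks) until the current request's allocation fits, or until the current request would be preempting itself. A condensed sketch of the same control flow against a toy block pool (hypothetical helper, not the real KVCacheManager):

```python
from collections import deque

free_blocks = 2                                  # toy KV block pool
blocks_held = {"r0": 3, "r1": 2, "r2": 2}        # blocks each running request holds
running = ["r0", "r1", "r2"]                     # r0 is the oldest request
waiting: deque[str] = deque()


def try_allocate(req_id: str, need: int) -> bool:
    """Toy allocator: succeeds only if enough free blocks remain."""
    global free_blocks
    if need > free_blocks:
        return False
    free_blocks -= need
    blocks_held[req_id] = blocks_held.get(req_id, 0) + need
    return True


request, need = "r0", 5              # r0 needs 5 more blocks this step
while True:
    if try_allocate(request, need):
        can_schedule = True
        break
    # Not enough blocks: preempt the newest running request, reclaim its
    # blocks, and put it at the *front* of the waiting queue to resume first.
    victim = running.pop()
    free_blocks += blocks_held.pop(victim)
    waiting.appendleft(victim)
    if victim == request:
        can_schedule = False         # we would be preempting ourselves; give up
        break

print(can_schedule, free_blocks, list(waiting))  # True 1 ['r1', 'r2']
```
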
+ for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary deque to collect requests that need to be skipped + # and put back at the head of the waiting queue later + skipped_waiting_requests: deque[Request] = deque() + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting[0] + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id + not in scheduled_loras): + # Scheduling would exceed max_loras, skip. + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Get already-cached tokens. + computed_blocks, num_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks(request) + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed requests, + # which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + else: + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens, computed_blocks) + if new_blocks is None: + # The request cannot be scheduled. 
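
For a waiting request, the tokens actually scheduled are the request's total minus whatever a prefix-cache hit already covers, further capped by the chunked-prefill threshold and the remaining step budget. A tiny numeric sketch with made-up values:

```python
# Hypothetical numbers for one waiting request.
num_prompt_tokens = 4096             # request.num_tokens for a fresh request
num_computed_tokens = 3072           # covered by a prefix-cache hit (full blocks)
long_prefill_token_threshold = 512   # optional chunking cap, 0 means disabled
token_budget = 800                   # tokens left in this scheduling step

num_new_tokens = num_prompt_tokens - num_computed_tokens   # 1024 still to run
if 0 < long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = long_prefill_token_threshold           # chunk long prefills
num_new_tokens = min(num_new_tokens, token_budget)

print(num_new_tokens)  # 512: only one chunk of the remaining prefill runs now
```
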
+ break + + self.waiting.popleft() + if request.use_structured_output: + structured_output_request_ids[ + request.request_id] = req_index + req_index += 1 + self.running.append(request) + self.scheduled_req_ids.add(request.request_id) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in computed_blocks + new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = 0 + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + len(self.running), + ) + # Construct the scheduler output. 
+ new_reqs_data = [ + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=False, + ) for req in scheduled_running_reqs + ] + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + self.requests[req_id].num_computed_tokens += num_scheduled_token + + self.finished_req_ids = set() + return scheduler_output + + def _make_cached_request_data( + self, + request: Request, + num_scheduled_tokens: int, + num_scheduled_spec_tokens: int, + new_block_ids: list[int], + resumed_from_preemption: bool, + ) -> CachedRequestData: + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating + # them at each scheduling step. 
+ num_computed_tokens = request.num_computed_tokens + num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens + new_token_ids = request.all_token_ids[ + num_computed_tokens:num_computed_tokens + num_regular_tokens] + req_data = self._cached_reqs_data.get(request.request_id) + if req_data is not None: + req_data.resumed_from_preemption = resumed_from_preemption + req_data.new_token_ids = new_token_ids + req_data.new_block_ids = new_block_ids + req_data.num_computed_tokens = num_computed_tokens + else: + req_data = CachedRequestData.from_request(request, + resumed_from_preemption, + new_token_ids, + new_block_ids) + self._cached_reqs_data[request.request_id] = req_data + return req_data + + def _try_schedule_encoder_inputs( + self, + request: Request, + num_computed_tokens: int, + num_new_tokens: int, + encoder_budget: int, + ) -> tuple[list[int], int, int]: + """ + Determine which encoder inputs need to be scheduled in the current step, + and update `num_new_tokens` and encoder token budget accordingly. + + An encoder input will be scheduled if: + - Its output tokens overlap with the range of tokens being computed + in this step, i.e., + [num_computed_tokens, num_computed_tokens + num_new_tokens). + - It is not already computed and stored in the encoder cache. + - There is sufficient encoder token budget to process it. + - The encoder cache has space to store it. + + If an encoder input cannot be scheduled due to cache or budget + limitations, the method adjusts `num_new_tokens` to schedule only the + decoder tokens up to just before the unschedulable encoder input. + """ + encoder_inputs_to_schedule: list[int] = [] + mm_positions = request.mm_positions + assert mm_positions is not None + assert len(mm_positions) > 0 + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, num_computed_tokens + num_new_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_new_tokens: + # The encoder input is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder input is already computed and stored + # in the decoder's KV cache. + continue + + if self.encoder_cache_manager.has_cache(request, i): + # The encoder input is already computed and cached. + continue + if (not self.encoder_cache_manager.can_allocate(request, i) + or num_encoder_tokens > encoder_budget): + # The encoder cache is full or the encoder budget is exhausted. + # NOTE(woosuk): We assume that the encoder input tokens should + # be processed altogether, as the encoder usually uses + # bidirectional attention. + if num_computed_tokens < start_pos: + # We only schedule the decoder tokens just before the + # encoder input. + num_new_tokens = start_pos - num_computed_tokens + else: + # Because of prefix caching, num_computed_tokens is greater + # than start_pos even though its encoder input is not + # available. In this case, we can't schedule any token for + # the request in this step. 
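
The overlap condition described above is a plain half-open interval test between the token range computed this step and each multimodal placeholder's position range. A standalone sketch of that test (hypothetical helper, not vLLM's API):

```python
def ranges_overlap(a_start: int, a_len: int, b_start: int, b_len: int) -> bool:
    """True if half-open ranges [a_start, a_start + a_len) and
    [b_start, b_start + b_len) share at least one position."""
    return a_start < b_start + b_len and b_start < a_start + a_len


# Tokens computed this step: [1000, 1256). Image placeholder at [1200, 1600):
assert ranges_overlap(1000, 256, 1200, 400)       # encoder output needed now
# Placeholder entirely before the window: already in the decoder KV cache.
assert not ranges_overlap(1000, 256, 0, 576)
# Placeholder entirely after the window: not needed yet in this step.
assert not ranges_overlap(1000, 256, 1300, 400)
```
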
+ num_new_tokens = 0 + break + + encoder_budget -= num_encoder_tokens + encoder_inputs_to_schedule.append(i) + return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> EngineCoreOutputs: + sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + new_running: list[Request] = [] + outputs: list[EngineCoreOutput] = [] + spec_decoding_stats: Optional[SpecDecodingStats] = None + + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below + # loop can be a performance bottleneck. We should do our best to avoid + # expensive operations inside the loop. + for request in self.running: + req_id = request.request_id + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + new_running.append(request) + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[req_index] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens, where is given by: + # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). + num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - + len(generated_token_ids)) + request.num_computed_tokens -= num_tokens_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=len(scheduled_spec_token_ids), + num_accepted_tokens=len(generated_token_ids) - 1) + + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if cached_encoder_input_ids: + for input_id in list(cached_encoder_input_ids): + mm_positions = request.mm_positions[input_id] + start_pos = mm_positions["offset"] + num_tokens = mm_positions["length"] + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + request.spec_token_ids = spec_token_ids[req_index] + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + self._free_request(request) + del new_token_ids[num_new:] # Trim new tokens if needed. + break + + # Extract sample logprobs if needed. 
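
The rejection arithmetic above is easy to misread, so here it is with concrete numbers: if k draft tokens are scheduled, the target model verifies them and samples one bonus token, so it returns between 1 and k + 1 tokens; any shortfall means rejected drafts, which are subtracted back out of `num_computed_tokens`. A small sketch with made-up values:

```python
# One request in one step, hypothetical values.
num_scheduled_spec_tokens = 4          # draft tokens sent for verification
num_generated_tokens = 3               # tokens the target model actually kept

# Of the k drafts plus 1 bonus token, only num_generated_tokens survived.
num_accepted_drafts = num_generated_tokens - 1                       # 2
num_tokens_rejected = (num_scheduled_spec_tokens + 1
                       - num_generated_tokens)                       # 2

# The scheduler optimistically advanced num_computed_tokens by the full
# scheduled amount, so the rejected tail is subtracted back out here.
num_computed_tokens = 1000
num_computed_tokens -= num_tokens_rejected
print(num_accepted_drafts, num_tokens_rejected, num_computed_tokens)  # 2 2 998
```
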
+ if request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and request.use_structured_output: + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids: + # Add EngineCoreOutput for this Request. + outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason, + events=request.take_events())) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + self.scheduled_req_ids.remove(req_id) + if not stopped: + new_running.append(request) + + self.running = new_running + engine_core_outputs = EngineCoreOutputs( + outputs=outputs, + scheduler_stats=self.make_stats(spec_decoding_stats), + ) + if self.include_finished_set: + #TODO currently sending duplicates here, improve this + engine_core_outputs.finished_requests = ( + scheduler_output.finished_req_ids | self.finished_req_ids) + + return engine_core_outputs + + def add_request(self, request: Request) -> None: + self.waiting.append(request) + self.requests[request.request_id] = request + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. + """ + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids, ) + else: + request_ids = set(request_ids) + + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. 
+ continue + + if request.status == RequestStatus.RUNNING: + self.running.remove(request) + self.scheduled_req_ids.discard(request.request_id) + else: + self.waiting.remove(request) + request.status = finished_status + self._free_request(request) + + def _free_request(self, request: Request) -> None: + assert request.is_finished() + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + self.encoder_cache_manager.free(request) + self._cached_reqs_data.pop(request.request_id, None) + del self.requests[request.request_id] + self.finished_req_ids.add(request.request_id) + + def get_num_unfinished_requests(self) -> int: + return len(self.waiting) + len(self.running) + + def has_finished_requests(self) -> bool: + return len(self.finished_req_ids) > 0 + + def get_num_unscheduled_requests(self) -> int: + """Number of requests that are not being processed by the executor.""" + return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() + + def make_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats] = None, + ) -> Optional[SchedulerStats]: + if not self.log_stats: + return None + return SchedulerStats( + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + gpu_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), + spec_decoding_stats=spec_decoding_stats, + ) + + def make_spec_decoding_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats], + num_draft_tokens: int, + num_accepted_tokens: int, + ) -> Optional[SpecDecodingStats]: + if not self.log_stats: + return None + if spec_decoding_stats is None: + spec_decoding_stats = SpecDecodingStats() + spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) + return spec_decoding_stats diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py new file mode 100644 index 0000000..3a0028a --- /dev/null +++ b/vllm/v1/core/sched/utils.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.v1.request import Request, RequestStatus + + +def check_stop(request: Request, max_model_len: int) -> bool: + if (request.num_tokens >= max_model_len + or request.num_output_tokens >= request.max_tokens): + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + return True + + sampling_params = request.sampling_params + last_token_id = request.output_token_ids[-1] + if (not sampling_params.ignore_eos + and last_token_id == request.eos_token_id): + request.status = RequestStatus.FINISHED_STOPPED + return True + + if last_token_id in (sampling_params.stop_token_ids or ()): + request.status = RequestStatus.FINISHED_STOPPED + request.stop_reason = last_token_id + return True + return False diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py new file mode 100644 index 0000000..7a8a983 --- /dev/null +++ b/vllm/v1/core/specialized_manager.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod + +from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + SlidingWindowSpec) + + +class SpecializedManager(ABC): + """ + An abstract base class for specialized managers that handle the kv + cache management logic of different attention 
layers. + """ + + def __init__( + self, + kv_cache_spec: KVCacheSpec, + block_pool: BlockPool, + ) -> None: + """ + Initializes the SpecializedManager. + Args: + kv_cache_spec: The kv_cache_spec for this manager. + block_pool: The block pool. + """ + + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_pool = block_pool + + @abstractmethod + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + """ + Get the longest cache hit prefix of the blocks. If no cache hit is + found, return an empty list. + + Args: + block_hashes: The block hashes of the request. + Returns: + A list of cached blocks with skipped blocks replaced by null block. + For example, sliding window manager should return a list like + [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and + sliding window 8. + """ + + raise NotImplementedError + + @abstractmethod + def remove_skipped_blocks(self, blocks: list[KVCacheBlock], + num_computed_tokens: int) -> list[KVCacheBlock]: + """ + Remove the blocks that are no longer needed from `blocks`. The removed + blocks should be replaced by null_block. Return the removed blocks in + eviction order, where the first returned block should be evicted first. + Don't free the removed blocks in this function. + + Args: + blocks: The list of blocks to be updated. + num_computed_tokens: The number of tokens that have been computed. + Returns: + The removed blocks in eviction order. + """ + raise NotImplementedError + + +class FullAttentionManager(SpecializedManager): + + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + computed_blocks: list[KVCacheBlock] = [] + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self.block_pool.get_cached_block(block_hash): + computed_blocks.append(cached_block) + else: + break + return computed_blocks + + def remove_skipped_blocks(self, blocks: list[KVCacheBlock], + num_computed_tokens: int) -> list[KVCacheBlock]: + # No need to remove blocks for full attention. + return [] + + +class SlidingWindowManager(SpecializedManager): + + def __init__(self, kv_cache_spec: SlidingWindowSpec, + block_pool: BlockPool): + super().__init__(kv_cache_spec, block_pool) + self.sliding_window = kv_cache_spec.sliding_window + # The number of contiguous blocks needed for prefix cache hit. + # -1 since the input token itself is also included in the window + self.sliding_window_contiguous_blocks = cdiv( + (kv_cache_spec.sliding_window - 1), self.block_size) + self._null_block = block_pool.null_block + + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to + # optimize the time complexity from O(len(block_hashes)) to + # O(len(block_hashes) / sliding_window_contiguous_blocks + + # sliding_window_contiguous_blocks), + # which is good for low cache hit rate scenarios. + computed_blocks = [self._null_block] * len(block_hashes) + num_contiguous_blocks = 0 + + # Search from right to left and early stop when a match is found. 
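
The `sliding_window_contiguous_blocks` value above is how many consecutive cached blocks are enough for a sliding-window prefix hit: the window covers the current token plus `sliding_window - 1` earlier tokens, rounded up to whole blocks. A quick check with concrete, illustrative numbers:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as commonly used for block-count math."""
    return -(-a // b)


block_size = 4

# Window of 8 tokens: the current token plus 7 predecessors -> 2 blocks of 4.
assert cdiv(8 - 1, block_size) == 2
# Window of 10 tokens: 9 predecessors no longer fit in 2 blocks -> 3 blocks.
assert cdiv(10 - 1, block_size) == 3
```
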
+ for i in range(len(block_hashes) - 1, -1, -1): + if cached_block := self.block_pool.get_cached_block( + block_hashes[i]): + computed_blocks[i] = cached_block + num_contiguous_blocks += 1 + if (num_contiguous_blocks + >= self.sliding_window_contiguous_blocks): + # Trim the trailing blocks. + # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] + # when sliding_window_contiguous_blocks=2. + del computed_blocks[i + num_contiguous_blocks:] + return computed_blocks + else: + num_contiguous_blocks = 0 + # The first `num_contiguous_blocks` is a cache hit even if + # `num_contiguous_blocks < sliding_window_contiguous_blocks`. + del computed_blocks[num_contiguous_blocks:] + return computed_blocks + + def remove_skipped_blocks(self, blocks: list[KVCacheBlock], + num_computed_tokens: int) -> list[KVCacheBlock]: + # Remove the blocks that are no longer be in the sliding window and + # skipped during the attention computation. + last_useful_token = num_computed_tokens - self.sliding_window + 1 + last_useful_block = last_useful_token // self.block_size + + removed_blocks: list[KVCacheBlock] = [] + for i in range(last_useful_block - 1, -1, -1): + if blocks[i] == self._null_block: + # If the block is already a null block, the blocks before it + # should also have been set to null blocks by the previous calls + # to this function. + break + removed_blocks.append(blocks[i]) + blocks[i] = self._null_block + return removed_blocks + + +spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = { + FullAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager, +} + + +def get_specialized_manager(kv_cache_spec: KVCacheSpec, + block_pool: BlockPool) -> SpecializedManager: + manager_class = spec_manager_map[type(kv_cache_spec)] + manager = manager_class(kv_cache_spec, block_pool) + return manager diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py new file mode 100644 index 0000000..0557d0c --- /dev/null +++ b/vllm/v1/engine/__init__.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +import time +from typing import Any, Optional, Union + +import msgspec + +from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.inputs import PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +# These are possible values of RequestOutput.finish_reason, +# so form part of the external API. +FINISH_REASON_STRINGS = ("stop", "length", "abort") + + +class FinishReason(enum.IntEnum): + """ + Reason a request finished - stop, length, or abort. + + Int rather than Str for more compact serialization. + + stop - a stop string was emitted + length - max_tokens was consumed, or max_model_len was reached + abort - aborted for another reason + + """ + STOP = 0 + LENGTH = 1 + ABORT = 2 + + def __str__(self): + return FINISH_REASON_STRINGS[self.value] + + +class EngineCoreRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + + # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, + # but this object is currently not playing well with msgspec + # due to circular imports and typing we have in data.py + + request_id: str + # NOTE(ywang96): original text prompt is needed when a request is added to + # Detokenizer, but set to None when it is added to EngineCoreClient. 
+ prompt: Optional[str] + prompt_token_ids: list[int] + mm_inputs: Optional[list[MultiModalKwargs]] + mm_hashes: Optional[list[str]] + mm_placeholders: Optional[list[PlaceholderRange]] + sampling_params: SamplingParams + eos_token_id: Optional[int] + arrival_time: float + lora_request: Optional[LoRARequest] + + +class EngineCoreEventType(enum.IntEnum): + """The type of engine core request event.""" + QUEUED = 1 + SCHEDULED = 2 + PREEMPTED = 3 + + +class EngineCoreEvent(msgspec.Struct): + """A timestamped engine core event associated with a request. + + The timestamp is a monotonic timestamps and is used for by the engine + frontend to calculate intervals between engine core events. These + timestamps should not be compared with timestamps from other processes. + """ + type: EngineCoreEventType + timestamp: float + + @classmethod + def new_event(cls, + event_type: EngineCoreEventType, + timestamp: Optional[float] = None) -> "EngineCoreEvent": + timestamp = time.monotonic() if timestamp is None else timestamp + return cls(event_type, timestamp) + + +class EngineCoreOutput( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + + request_id: str + new_token_ids: list[int] + + new_logprobs: Optional[LogprobsLists] = None + new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + + finish_reason: Optional[FinishReason] = None + stop_reason: Union[int, str, None] = None + events: Optional[list[EngineCoreEvent]] = None + + @property + def finished(self) -> bool: + return self.finish_reason is not None + + +class UtilityOutput( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + + call_id: int + + # Non-None implies the call failed, result should be None. + failure_message: Optional[str] = None + result: Any = None + + +class EngineCoreOutputs( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + + #NOTE(Nick): We could consider ways to make this more compact, + # e.g. columnwise layout + + engine_index: int = 0 + + # [num_reqs] + outputs: list[EngineCoreOutput] = [] + scheduler_stats: Optional[SchedulerStats] = None + timestamp: float = 0.0 + + utility_output: Optional[UtilityOutput] = None + finished_requests: Optional[set[str]] = None + + # In DP case, used to signal that the engine is paused. + engine_paused: bool = False + + def __post_init__(self): + if self.timestamp == 0.0: + self.timestamp = time.monotonic() + + +class EngineCoreRequestType(enum.Enum): + """ + Request types defined as hex byte strings, so it can be sent over sockets + without separate encoding step. 
+ """ + ADD = b'\x00' + ABORT = b'\x01' + START_DP = b'\x02' + UTILITY = b'\x03' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py new file mode 100644 index 0000000..b77a682 --- /dev/null +++ b/vllm/v1/engine/async_llm.py @@ -0,0 +1,463 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import logging +import os +from collections.abc import AsyncGenerator, Mapping +from copy import copy +from typing import Optional, Union + +import numpy as np + +import vllm.envs as envs +from vllm.config import ModelConfig, VllmConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.protocol import EngineClient +from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.outputs import RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Device, cdiv, kill_process_tree +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.output_processor import (OutputProcessor, + RequestOutputCollector) +from vllm.v1.engine.parallel_sampling import ParentRequest +from vllm.v1.engine.processor import Processor +from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, + StatLoggerBase) +from vllm.v1.metrics.stats import IterationStats, SchedulerStats + +logger = init_logger(__name__) + + +class AsyncLLM(EngineClient): + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + log_requests: bool = True, + start_engine_loop: bool = True, + ) -> None: + if not envs.VLLM_USE_V1: + raise ValueError( + "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " + "This should not happen. As a workaround, try using " + "AsyncLLMEngine.from_vllm_config(...) or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + assert start_engine_loop + + self.model_config = vllm_config.model_config + + self.log_requests = log_requests + self.log_stats = log_stats + + # Set up stat loggers; independent set for each DP rank. + self.stat_loggers: list[list[StatLoggerBase]] = [] + if self.log_stats: + for i in range(vllm_config.parallel_config.data_parallel_size): + loggers: list[StatLoggerBase] = [] + if logger.isEnabledFor(logging.INFO): + loggers.append(LoggingStatLogger(engine_index=i)) + loggers.append( + PrometheusStatLogger(vllm_config, engine_index=i)) + self.stat_loggers.append(loggers) + + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) + self.tokenizer.ping() + + # Processor (converts Inputs --> EngineCoreRequests). 
+ self.processor = Processor( + vllm_config=vllm_config, + tokenizer=self.tokenizer, + mm_registry=mm_registry, + ) + + # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). + self.output_processor = OutputProcessor(self.tokenizer, + log_stats=self.log_stats) + + # EngineCore (starts the engine in background process). + self.engine_core = EngineCoreClient.make_client( + multiprocess_mode=True, + asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=self.log_stats, + ) + + self.output_handler: Optional[asyncio.Task] = None + + @classmethod + def from_vllm_config( + cls, + vllm_config: VllmConfig, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + disable_log_requests: bool = False, + disable_log_stats: bool = False, + ) -> "AsyncLLM": + if not envs.VLLM_USE_V1: + raise ValueError( + "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " + "This should not happen. As a workaround, try using " + "AsyncLLMEngine.from_vllm_config(...) or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + # FIXME(rob): refactor VllmConfig to include the StatLoggers + # include StatLogger in the Oracle decision. + if stat_loggers is not None: + raise ValueError("Custom StatLoggers are not yet supported on V1. " + "Explicitly set VLLM_USE_V1=0 to disable V1.") + + # Create the LLMEngine. + return cls( + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + start_engine_loop=start_engine_loop, + log_requests=not disable_log_requests, + log_stats=not disable_log_stats, + usage_context=usage_context, + ) + + @classmethod + def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + ) -> "AsyncLLM": + """Create an AsyncLLM from the EngineArgs.""" + + # Create the engine configs. + vllm_config = engine_args.create_engine_config(usage_context) + executor_class = Executor.get_class(vllm_config) + + # Create the AsyncLLM. + return cls( + vllm_config=vllm_config, + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + ) + + def shutdown(self): + """Shutdown, cleaning up the background proc and IPC.""" + + if engine_core := getattr(self, "engine_core", None): + engine_core.shutdown() + + if handler := getattr(self, "output_handler", None): + handler.cancel() + + async def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> RequestOutputCollector: + """Add new request to the AsyncLLM.""" + + assert isinstance(params, SamplingParams), \ + "Pooling is not supported in V1" + + # Create a new output collector for the request. + queue = RequestOutputCollector(output_kind=params.output_kind) + + # Convert Input --> Request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) + + if params.n == 1: + await self._add_request(request, None, 0, queue) + return queue + + # Fan out child requests (for n>1). 
+ parent_request = ParentRequest(request_id, params) + for idx in range(params.n): + request_id, params = parent_request.get_child_info(idx) + child_request = request if idx == params.n - 1 else copy(request) + child_request.request_id = request_id + child_request.sampling_params = params + await self._add_request(child_request, parent_request, idx, queue) + return queue + + async def _add_request(self, request: EngineCoreRequest, + parent_req: Optional[ParentRequest], index: int, + queue: RequestOutputCollector): + + # Add the request to OutputProcessor (this process). + self.output_processor.add_request(request, parent_req, index, queue) + + # Add the EngineCoreRequest to EngineCore (separate process). + await self.engine_core.add_request_async(request) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + # TODO: we should support multiple prompts in one call, as you + # can do with LLM.generate. So that for multi-prompt completion + # requests we don't need to send multiple messages to core proc, + # and so we don't need multiple streams which then get + # re-multiplexed in the API server anyhow. + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """ + Main function called by the API server to kick off a request + * 1) Making an AsyncStream corresponding to the Request. + * 2) Processing the Input. + * 3) Adding the Request to the Detokenizer. + * 4) Adding the Request to the EngineCore (separate process). + + A separate output_handler loop runs in a background AsyncIO task, + pulling outputs from EngineCore and putting them into the + per-request AsyncStream. + + The caller of generate() iterates the returned AsyncGenerator, + returning the RequestOutput back to the caller. + """ + + try: + # We start the output_handler on the first call to generate() so + # we can call __init__ before the event loop, which enables us + # to handle startup failure gracefully in the OpenAI server. + if self.output_handler is None: + self.output_handler = asyncio.create_task( + self._run_output_handler()) + + q = await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + + # The output_handler task pushes items into the queue. + # This task pulls from the queue and yields to caller. + finished = False + while not finished: + # Note: drain queue without await if possible (avoids + # task switching under load which helps performance). + out = q.get_nowait() or await q.get() + + # Note: both OutputProcessor and EngineCore handle their + # own request cleanup based on finished. + finished = out.finished + yield out + + # If the request is disconnected by the client, the + # generate() task will be canceled. So, we abort the + # request if we end up here. + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def _run_output_handler(self): + """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" + + try: + while True: + # 1) Pull EngineCoreOutputs from the EngineCore. 
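
One idiom in `generate()` above is worth calling out: `q.get_nowait() or await q.get()` drains the per-request queue without awaiting when data is already there, and only suspends when it is empty, avoiding a task switch per output under load (it relies on outputs being truthy). A minimal sketch of the pattern with a toy collector, not vLLM's RequestOutputCollector:

```python
import asyncio
from typing import Optional


class ToyCollector:
    """Simplified output collector: get_nowait() returns None when empty."""

    def __init__(self) -> None:
        self._queue: asyncio.Queue[str] = asyncio.Queue()

    def put(self, item: str) -> None:
        self._queue.put_nowait(item)

    def get_nowait(self) -> Optional[str]:
        return None if self._queue.empty() else self._queue.get_nowait()

    async def get(self) -> str:
        return await self._queue.get()


async def consume(q: ToyCollector, n: int) -> list[str]:
    out = []
    for _ in range(n):
        # Fast path: drain without awaiting; slow path: suspend until data.
        # Assumes queued items are truthy (true for non-empty outputs).
        item = q.get_nowait() or await q.get()
        out.append(item)
    return out


async def main() -> None:
    q = ToyCollector()
    q.put("tok-1")
    q.put("tok-2")
    asyncio.get_running_loop().call_later(0.01, q.put, "tok-3")
    print(await consume(q, 3))  # ['tok-1', 'tok-2', 'tok-3']


asyncio.run(main())
```
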
+ outputs = await self.engine_core.get_output_async() + num_outputs = len(outputs.outputs) + + iteration_stats = IterationStats() if ( + self.log_stats and num_outputs) else None + + # Split outputs into chunks of at most + # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the + # event loop for too long. + if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: + slices = (outputs.outputs, ) + else: + slices = np.array_split( + outputs.outputs, + cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + + for i, outputs_slice in enumerate(slices): + # 2) Process EngineCoreOutputs. + processed_outputs = self.output_processor.process_outputs( + outputs_slice, outputs.timestamp, iteration_stats) + # NOTE: RequestOutputs are pushed to their queues. + assert not processed_outputs.request_outputs + + # Allow other asyncio tasks to run between chunks + if i + 1 < len(slices): + await asyncio.sleep(0) + + # 3) Abort any reqs that finished due to stop strings. + await self.engine_core.abort_requests_async( + processed_outputs.reqs_to_abort) + + # 4) Logging. + # TODO(rob): make into a coroutine and launch it in + # background thread once Prometheus overhead is non-trivial. + self._record_stats( + engine_index=outputs.engine_index, + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + ) + + except Exception as e: + logger.exception("EngineCore output handler hit an error: %s", e) + kill_process_tree(os.getpid()) + + async def abort(self, request_id: str) -> None: + """Abort RequestId in OutputProcessor and EngineCore.""" + + request_ids = self.output_processor.abort_requests((request_id, )) + await self.engine_core.abort_requests_async(request_ids) + + if self.log_requests: + logger.info("Aborted request %s.", request_id) + + def _record_stats( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_index: int = 0, + ): + if not self.log_stats: + return + + assert scheduler_stats is not None + for stat_logger in self.stat_loggers[engine_index]: + stat_logger.record(scheduler_stats=scheduler_stats, + iteration_stats=iteration_stats) + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ): + raise ValueError("Not Supported on V1 yet.") + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def get_decoding_config(self): + raise ValueError("Not Supported on V1 yet.") + + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.processor.input_preprocessor + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.tokenizer.get_lora_tokenizer(lora_request) + + async def is_tracing_enabled(self) -> bool: + return False + + async def do_log_stats( + self, + scheduler_outputs=None, + model_output=None, + ) -> None: + for loggers in self.stat_loggers: + for stat_logger in loggers: + stat_logger.log() + + async def check_health(self) -> None: + logger.debug("Called check_health.") + + async def start_profile(self) -> None: + await self.engine_core.profile_async(True) + + async def stop_profile(self) -> None: + await self.engine_core.profile_async(False) + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + if device == Device.CPU: + raise ValueError("Not supported on CPU.") + await self.engine_core.reset_prefix_cache_async() + + 
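
The chunking in `_run_output_handler` above bounds how long each iteration blocks the event loop: outputs are split into slices of at most `VLLM_V1_OUTPUT_PROC_CHUNK_SIZE` and control is yielded back between slices. A reduced sketch of the same idea with toy values and trivial per-output work:

```python
import asyncio
import math

import numpy as np

CHUNK_SIZE = 128  # stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE


async def process_outputs(outputs: list[int]) -> int:
    processed = 0
    if len(outputs) <= CHUNK_SIZE:
        slices = (outputs,)
    else:
        # Same ceil-division split as the handler above.
        slices = np.array_split(outputs, math.ceil(len(outputs) / CHUNK_SIZE))
    for i, chunk in enumerate(slices):
        processed += len(chunk)          # placeholder for real per-output work
        if i + 1 < len(slices):
            # Yield to the event loop so other coroutines are not starved.
            await asyncio.sleep(0)
    return processed


print(asyncio.run(process_outputs(list(range(300)))))  # 300, handled in 3 chunks
```
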
async def sleep(self, level: int = 1) -> None: + await self.engine_core.sleep_async(level) + + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + await self.engine_core.wake_up_async(tags) + + async def is_sleeping(self) -> bool: + return await self.engine_core.is_sleeping_async() + + async def add_lora(self, lora_request: LoRARequest) -> bool: + """Load a new LoRA adapter into the engine for future requests.""" + return await self.engine_core.add_lora_async(lora_request) + + async def remove_lora(self, lora_id: int) -> bool: + """Remove an already loaded LoRA adapter.""" + return await self.engine_core.remove_lora_async(lora_id) + + async def list_loras(self) -> set[int]: + """List all registered adapters.""" + return await self.engine_core.list_loras_async() + + async def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + return await self.engine_core.pin_lora_async(lora_id) + + @property + def is_running(self) -> bool: + return True + + @property + def is_stopped(self) -> bool: + return False + + @property + def errored(self) -> bool: + return False + + @property + def dead_error(self) -> BaseException: + return Exception() # TODO: implement diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py new file mode 100644 index 0000000..39caca0 --- /dev/null +++ b/vllm/v1/engine/core.py @@ -0,0 +1,622 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import queue +import signal +import sys +import threading +import time +from concurrent.futures import Future +from inspect import isclass, signature +from logging import DEBUG +from typing import Any, Callable, Optional, TypeVar, Union + +import msgspec +import psutil +import zmq +import zmq.asyncio + +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import stateless_destroy_torch_distributed_process_group +from vllm.executor.multiproc_worker_utils import _add_prefix +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) +from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, + zmq_socket_ctx) +from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, + unify_kv_cache_configs) +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler +from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType, UtilityOutput) +from vllm.v1.engine.mm_input_cache import MMInputCacheServer +from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.structured_output import StructuredOutputManager +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) + +POLLING_TIMEOUT_S = 2.5 + +_R = TypeVar('_R') # Return type for collective_rpc + + +class EngineCore: + """Inner loop of vLLM's Engine.""" + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + ): + assert vllm_config.model_config.runner_type != "pooling" + + logger.info("Initializing a V1 LLM engine (v%s) with config: %s", + VLLM_VERSION, vllm_config) + + self.log_stats = log_stats + + # Setup Model. 
+ self.model_executor = executor_class(vllm_config) + + # Setup KV Caches and update CacheConfig after profiling. + num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ + self._initialize_kv_caches(vllm_config) + + vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks + vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + + self.structured_output_manager = StructuredOutputManager(vllm_config) + + # Setup scheduler. + if isinstance(vllm_config.scheduler_config.scheduler_cls, str): + Scheduler = resolve_obj_by_qualname( + vllm_config.scheduler_config.scheduler_cls) + else: + Scheduler = vllm_config.scheduler_config.scheduler_cls + + # This warning can be removed once the V1 Scheduler interface is + # finalized and we can maintain support for scheduler classes that + # implement it + if Scheduler is not V1Scheduler: + logger.warning( + "Using configured V1 scheduler class %s. " + "This scheduler interface is not public and " + "compatibility may not be maintained.", + vllm_config.scheduler_config.scheduler_cls) + + self.scheduler: SchedulerInterface = Scheduler( + scheduler_config=vllm_config.scheduler_config, + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + kv_cache_config=kv_cache_config, + structured_output_manager=self.structured_output_manager, + include_finished_set=vllm_config.parallel_config.data_parallel_size + > 1, + log_stats=self.log_stats, + ) + + # Setup MM Input Mapper. + self.mm_input_cache_server = MMInputCacheServer( + vllm_config.model_config) + + # Setup batch queue for pipeline parallelism. + # Batch queue for scheduled batches. This enables us to asynchronously + # schedule and execute batches, and is required by pipeline parallelism + # to eliminate pipeline bubbles. + self.batch_queue_size = self.model_executor.max_concurrent_batches + self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput], + SchedulerOutput]]] = None + if self.batch_queue_size > 1: + logger.info("Batch queue is enabled with size %d", + self.batch_queue_size) + self.batch_queue = queue.Queue(self.batch_queue_size) + + def _initialize_kv_caches( + self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: + start = time.time() + + # Get all kv cache needed by the model + kv_cache_specs = self.model_executor.get_kv_cache_specs() + + # Profiles the peak memory usage of the model to determine how much + # memory can be allocated for kv cache. + available_gpu_memory = self.model_executor.determine_available_memory() + + assert len(kv_cache_specs) == len(available_gpu_memory) + # Get the kv cache tensor size + kv_cache_configs = [ + get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, + available_gpu_memory_one_worker) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker in + zip(kv_cache_specs, available_gpu_memory) + ] + + # Since we use a shared centralized controller, we need the + # `kv_cache_config` to be consistent across all workers to make sure + # all the memory operators can be applied to all workers. + unify_kv_cache_configs(kv_cache_configs) + + # All workers have the same kv_cache_config except layer names, so use + # an arbitrary one to initialize the scheduler. 
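+        # Sanity check: after unify_kv_cache_configs() every worker must
+        # report the same number of KV-cache blocks.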
+ assert all([ + cfg.num_blocks == kv_cache_configs[0].num_blocks + for cfg in kv_cache_configs + ]) + num_gpu_blocks = kv_cache_configs[0].num_blocks + num_cpu_blocks = 0 + scheduler_kv_cache_config = kv_cache_configs[0] + + # Initialize kv cache and warmup the execution + self.model_executor.initialize_from_config(kv_cache_configs) + + elapsed = time.time() - start + logger.info(("init engine (profile, create kv cache, " + "warmup model) took %.2f seconds"), elapsed) + return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config + + def add_request(self, request: EngineCoreRequest): + """Add request to the scheduler.""" + + if request.mm_hashes is not None: + # Here, if hash exists for a multimodal input, then it will be + # fetched from the cache, else it will be added to the cache. + # Note that the cache here is mirrored with the client cache, so + # anything that has a hash must have a HIT cache entry here + # as well. + assert request.mm_inputs is not None + request.mm_inputs = self.mm_input_cache_server.get_and_update( + request.mm_inputs, request.mm_hashes) + + req = Request.from_engine_core_request(request) + if req.use_structured_output: + # Start grammar compilation asynchronously + self.structured_output_manager.grammar_init(req) + + self.scheduler.add_request(req) + + def abort_requests(self, request_ids: list[str]): + """Abort requests from the scheduler.""" + + # TODO: The scheduler doesn't really need to know the + # specific finish reason, TBD whether we propagate that + # (i.e. client-aborted vs stop criteria met). + self.scheduler.finish_requests(request_ids, + RequestStatus.FINISHED_ABORTED) + + def step(self) -> EngineCoreOutputs: + """Schedule, execute, and make output.""" + + # Check for any requests remaining in the scheduler - unfinished, + # or finished and not yet removed from the batch. + if not self.scheduler.has_requests(): + return EngineCoreOutputs( + outputs=[], + scheduler_stats=self.scheduler.make_stats(), + ) + scheduler_output = self.scheduler.schedule() + output = self.model_executor.execute_model(scheduler_output) + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, output) # type: ignore + + return engine_core_outputs + + def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: + """Schedule and execute batches with the batch queue. + Note that if nothing to output in this step, None is returned. + + The execution flow is as follows: + 1. Try to schedule a new batch if there are unscheduled requests + and the job queue is not full. If a new batch is scheduled, directly + return an empty engine core output. In other words, we won't check + and return model outputs before the batch queue is full. + 2. If there is no new scheduled batch, meaning that the batch queue + is full or no other requests can be scheduled, we block until the first + batch in the job queue is finished. + 3. Update the scheduler from the output. + """ + assert self.batch_queue is not None + + engine_core_outputs = None + scheduler_output = None + # If there are unscheduled requests and the job queue + # is not full, schedule a new batch. Note that this is not blocking. 
+ if (self.scheduler.get_num_unscheduled_requests() > 0 + and not self.batch_queue.full()): + scheduler_output = self.scheduler.schedule() + if scheduler_output.total_num_scheduled_tokens > 0: + future = self.model_executor.execute_model(scheduler_output) + self.batch_queue.put_nowait( + (future, scheduler_output)) # type: ignore + + scheduled_batch = (scheduler_output is not None + and scheduler_output.total_num_scheduled_tokens > 0) + + # If no more requests can be scheduled and the job queue is not empty, + # block until the first batch in the job queue is finished. + if not scheduled_batch and not self.batch_queue.empty(): + future, scheduler_output = self.batch_queue.get_nowait() + # Blocking until the first result is available. + model_output = future.result() + self.batch_queue.task_done() + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output) + + return engine_core_outputs + + def shutdown(self): + self.model_executor.shutdown() + + def profile(self, is_start: bool = True): + self.model_executor.profile(is_start) + + def reset_prefix_cache(self): + self.scheduler.reset_prefix_cache() + + def sleep(self, level: int = 1): + self.model_executor.sleep(level) + + def wake_up(self, tags: Optional[list[str]] = None): + self.model_executor.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.model_executor.is_sleeping + + def execute_dummy_batch(self): + self.model_executor.collective_rpc("execute_dummy_batch") + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_executor.save_sharded_state(path=path, + pattern=pattern, + max_size=max_size) + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, + kwargs) + + +class EngineCoreProc(EngineCore): + """ZMQ-wrapper for running EngineCore in background process.""" + + def __init__( + self, + input_path: str, + output_path: str, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + engine_index: int = 0, + ): + super().__init__(vllm_config, executor_class, log_stats) + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
+ self.input_queue: queue.Queue[tuple[EngineCoreRequestType, + Any]] = queue.Queue() + self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + threading.Thread(target=self.process_input_socket, + args=(input_path, ), + daemon=True).start() + threading.Thread(target=self.process_output_socket, + args=(output_path, engine_index), + daemon=True).start() + + self.global_unfinished_reqs = False + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + + @staticmethod + def run_engine_core(*args, + dp_rank: int = 0, + local_dp_rank: int = 0, + ready_pipe, + **kwargs): + """Launch EngineCore busy loop in background process.""" + + # Signal handler used for graceful termination. + # SystemExit exception is only raised once to allow this and worker + # processes to terminate without error + shutdown_requested = False + + # Ensure we can serialize transformer config after spawning + maybe_register_config_serialize_by_value() + + def signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the engine_core + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + parent_process = psutil.Process().parent() + engine_core: Optional[EngineCoreProc] = None + try: + parallel_config: ParallelConfig = kwargs[ + "vllm_config"].parallel_config + if parallel_config.data_parallel_size > 1: + # Set data parallel rank for this engine process. + parallel_config.data_parallel_rank = dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank + engine_core = DPEngineCoreProc(*args, **kwargs) + else: + engine_core = EngineCoreProc(*args, **kwargs) + + # Send Readiness signal to EngineClient. + ready_pipe.send({"status": "READY"}) + + engine_core.run_busy_loop() + + except SystemExit: + logger.debug("EngineCore interrupted.") + + except Exception: + traceback = get_exception_traceback() + logger.error("EngineCore hit an exception: %s", traceback) + parent_process.send_signal(signal.SIGUSR1) + + finally: + if engine_core is not None: + engine_core.shutdown() + + def run_busy_loop(self): + """Core busy loop of the EngineCore.""" + + # Loop until process is sent a SIGINT or SIGTERM + while True: + # 1) Poll the input queue until there is work to do. + self._process_input_queue() + # 2) Step the engine core and return the outputs. + self._process_engine_step() + + def _process_input_queue(self): + """Exits when an engine step needs to be performed.""" + + waited = False + while not self.global_unfinished_reqs and not ( + self.scheduler.has_requests()): + if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): + logger.debug("EngineCore waiting for work.") + waited = True + req = self.input_queue.get() + self._handle_client_request(*req) + + if waited: + logger.debug( + "EngineCore loop active - local unfinished: %s, finished: %s.", + self.scheduler.has_unfinished_requests(), + self.scheduler.has_finished_requests()) + + # Handle any more client requests. + while not self.input_queue.empty(): + req = self.input_queue.get_nowait() + self._handle_client_request(*req) + + def _process_engine_step(self): + """Called only when there are unfinished local requests.""" + + # Step the engine core. + outputs = self.step_fn() + # Put EngineCoreOutputs into the output queue. 
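+        # step_with_batch_queue() returns None when a batch was only
+        # scheduled (nothing to output yet), so guard before enqueueing.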
+ if outputs is not None: + self.output_queue.put_nowait(outputs) + + def _handle_client_request(self, request_type: EngineCoreRequestType, + request: Any) -> None: + """Dispatch request from client.""" + + if request_type == EngineCoreRequestType.ADD: + self.add_request(request) + elif request_type == EngineCoreRequestType.ABORT: + self.abort_requests(request) + elif request_type == EngineCoreRequestType.START_DP: + if not self.global_unfinished_reqs: + logger.debug("EngineCore starting idle loop.") + self.global_unfinished_reqs = True + elif request_type == EngineCoreRequestType.UTILITY: + call_id, method_name, args = request + output = UtilityOutput(call_id) + try: + method = getattr(self, method_name) + output.result = method( + *self._convert_msgspec_args(method, args)) + except BaseException as e: + logger.exception("Invocation of %s method failed", method_name) + output.failure_message = (f"Call to {method_name} method" + f" failed: {str(e)}") + self.output_queue.put_nowait( + EngineCoreOutputs(utility_output=output)) + + @staticmethod + def _convert_msgspec_args(method, args): + """If a provided arg type doesn't match corresponding target method + arg type, try converting to msgspec object.""" + if not args: + return args + arg_types = signature(method).parameters.values() + assert len(args) <= len(arg_types) + return tuple( + msgspec.convert(v, type=p.annotation) if isclass(p.annotation) + and issubclass(p.annotation, msgspec.Struct) + and not isinstance(v, p.annotation) else v + for v, p in zip(args, arg_types)) + + def process_input_socket(self, input_path: str): + """Input socket IO thread.""" + + # Msgpack serialization decoding. + add_request_decoder = MsgpackDecoder(EngineCoreRequest) + generic_decoder = MsgpackDecoder() + + with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: + while True: + # (RequestType, RequestData) + type_frame, data_frame = socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frame.buffer) + + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) + + def process_output_socket(self, output_path: str, engine_index: int): + """Output socket IO thread.""" + + # Msgpack serialization encoding. + encoder = MsgpackEncoder() + # Reuse send buffer. + buffer = bytearray() + + with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: + while True: + outputs = self.output_queue.get() + outputs.engine_index = engine_index + encoder.encode_into(outputs, buffer) + socket.send(buffer, copy=False) + + +ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True) + + +class DPEngineCoreProc(EngineCoreProc): + """ZMQ-wrapper for running EngineCore in background process + in a data parallel context.""" + + def __init__( + self, + input_path: str, + output_path: str, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + ): + # Add process-specific prefix to stdout and stderr before + # we initialize the engine. 
+ from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + dp_size = vllm_config.parallel_config.data_parallel_size + dp_rank = vllm_config.parallel_config.data_parallel_rank + local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + + assert dp_size > 1 + assert 0 <= local_dp_rank <= dp_rank < dp_size + + from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): + from vllm.platforms.cuda import device_id_to_physical_device_id + tp_size = vllm_config.parallel_config.tensor_parallel_size + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + str(device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * + tp_size)) + + self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + + # Initialize the engine after setting up environment. + super().__init__(input_path, output_path, vllm_config, executor_class, + log_stats, dp_rank) + + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + + def shutdown(self): + super().shutdown() + if dp_group := getattr(self, "dp_group", None): + stateless_destroy_torch_distributed_process_group(dp_group) + + def run_busy_loop(self): + """Core busy loop of the EngineCore for data parallel case.""" + + # Loop until process is sent a SIGINT or SIGTERM + while True: + # 1) Poll the input queue until there is work to do. + self._process_input_queue() + + local_unfinished_reqs = self.scheduler.has_unfinished_requests() + + if local_unfinished_reqs: + # 2) Step the engine core. + self._process_engine_step() + + # Check if we have now finished all requests. + local_unfinished_reqs = ( + self.scheduler.has_unfinished_requests()) + else: + if self.scheduler.has_finished_requests(): + # There are no unfinished requests, but there are some + # finished requests remaining to be removed from the + # batch state. This engine step won't perform a forward + # pass but will flush the finished requests to ensure + # up-to-date state is returned in the engine outputs. + self._process_engine_step() + + if not self.global_unfinished_reqs: + # All engines are idle. + continue + + # There must be unfinished requests in DP peers, run a + # dummy forward pass. + self.execute_dummy_batch() + + # 3) All-reduce operation to determine global unfinished reqs. + self.global_unfinished_reqs = self._has_global_unfinished_reqs( + local_unfinished_reqs) + + if not self.global_unfinished_reqs: + # Notify client that we are pausing the loop. + self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS) + + def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: + + # Optimization - only perform finish-sync all-reduce every 16 steps. 
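+        # Between syncs we conservatively report "unfinished" so that no
+        # engine pauses its busy loop early.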
+ self.counter += 1 + if self.counter != 16: + return True + self.counter = 0 + + return ParallelConfig.has_unfinished_dp(self.dp_group, + local_unfinished) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py new file mode 100644 index 0000000..e948e59 --- /dev/null +++ b/vllm/v1/engine/core_client.py @@ -0,0 +1,824 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import os +import queue +import signal +import threading +import uuid +import weakref +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Sequence +from concurrent.futures import Future +from dataclasses import dataclass, field +from threading import Thread +from typing import Any, Callable, Optional, TypeVar, Union + +import zmq +import zmq.asyncio + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, + kill_process_tree, make_zmq_socket) +from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType, UtilityOutput) +from vllm.v1.engine.core import EngineCore, EngineCoreProc +from vllm.v1.executor.abstract import Executor +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.utils import BackgroundProcHandle + +logger = init_logger(__name__) + +AnyFuture = Union[asyncio.Future[Any], Future[Any]] + +_R = TypeVar('_R') # Return type for collective_rpc + + +class EngineCoreClient(ABC): + """ + EngineCoreClient: subclasses handle different methods for pushing + and pulling from the EngineCore for asyncio / multiprocessing. + + Subclasses: + * InprocClient: In process EngineCore (for V0-style LLMEngine use) + * SyncMPClient: ZMQ + background proc EngineCore (for LLM) + * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM) + """ + + @staticmethod + def make_client( + multiprocess_mode: bool, + asyncio_mode: bool, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + ) -> "EngineCoreClient": + + # TODO: support this for debugging purposes. + if asyncio_mode and not multiprocess_mode: + raise NotImplementedError( + "Running EngineCore in asyncio without multiprocessing " + "is not currently supported.") + + if multiprocess_mode and asyncio_mode: + if vllm_config.parallel_config.data_parallel_size > 1: + return DPAsyncMPClient(vllm_config, executor_class, log_stats) + + return AsyncMPClient(vllm_config, executor_class, log_stats) + + if multiprocess_mode and not asyncio_mode: + return SyncMPClient(vllm_config, executor_class, log_stats) + + return InprocClient(vllm_config, executor_class, log_stats) + + @abstractmethod + def shutdown(self): + ... 
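+    # Synchronous interface (implemented by InprocClient and SyncMPClient);
+    # the *_async variants below are implemented by AsyncMPClient.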
+ + def get_output(self) -> EngineCoreOutputs: + raise NotImplementedError + + def add_request(self, request: EngineCoreRequest) -> None: + raise NotImplementedError + + def profile(self, is_start: bool = True) -> None: + raise NotImplementedError + + def reset_prefix_cache(self) -> None: + raise NotImplementedError + + def sleep(self, level: int = 1) -> None: + raise NotImplementedError + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + raise NotImplementedError + + def is_sleeping(self) -> bool: + raise NotImplementedError + + def execute_dummy_batch(self) -> None: + raise NotImplementedError + + async def execute_dummy_batch_async(self) -> None: + raise NotImplementedError + + def abort_requests(self, request_ids: list[str]) -> None: + raise NotImplementedError + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def list_loras(self) -> set[int]: + raise NotImplementedError + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def save_sharded_state(self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None) -> None: + raise NotImplementedError + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + raise NotImplementedError + + async def get_output_async(self) -> EngineCoreOutputs: + raise NotImplementedError + + async def add_request_async(self, request: EngineCoreRequest) -> None: + raise NotImplementedError + + async def profile_async(self, is_start: bool = True) -> None: + raise NotImplementedError + + async def reset_prefix_cache_async(self) -> None: + raise NotImplementedError + + async def sleep_async(self, level: int = 1) -> None: + raise NotImplementedError + + async def wake_up_async(self, tags: Optional[list[str]] = None) -> None: + raise NotImplementedError + + async def is_sleeping_async(self) -> bool: + raise NotImplementedError + + async def abort_requests_async(self, request_ids: list[str]) -> None: + raise NotImplementedError + + async def add_lora_async(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + async def remove_lora_async(self, lora_id: int) -> bool: + raise NotImplementedError + + async def list_loras_async(self) -> set[int]: + raise NotImplementedError + + async def pin_lora_async(self, lora_id: int) -> bool: + raise NotImplementedError + + async def save_sharded_state_async(self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None) -> None: + raise NotImplementedError + + async def collective_rpc_async( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + raise NotImplementedError + + +class InprocClient(EngineCoreClient): + """ + InprocClient: client for in-process EngineCore. Intended + for use in LLMEngine for V0-style add_request() and step() + EngineCore setup in this process (no busy loop). 
+ + * pushes EngineCoreRequest directly into the EngineCore + * pulls EngineCoreOutputs by stepping the EngineCore + """ + + def __init__(self, *args, **kwargs): + self.engine_core = EngineCore(*args, **kwargs) + + def get_output(self) -> EngineCoreOutputs: + return self.engine_core.step() + + def add_request(self, request: EngineCoreRequest) -> None: + self.engine_core.add_request(request) + + def abort_requests(self, request_ids: list[str]) -> None: + if len(request_ids) > 0: + self.engine_core.abort_requests(request_ids) + + def shutdown(self) -> None: + self.engine_core.shutdown() + + def profile(self, is_start: bool = True) -> None: + self.engine_core.profile(is_start) + + def reset_prefix_cache(self) -> None: + self.engine_core.reset_prefix_cache() + + def sleep(self, level: int = 1) -> None: + self.engine_core.sleep(level) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine_core.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.engine_core.is_sleeping() + + def execute_dummy_batch(self) -> None: + self.engine_core.execute_dummy_batch() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.engine_core.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.engine_core.remove_lora(lora_id) + + def list_loras(self) -> set[int]: + return self.engine_core.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.engine_core.pin_lora(lora_id) + + def save_sharded_state(self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None) -> None: + self.engine_core.save_sharded_state(path, pattern, max_size) + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.engine_core.collective_rpc(method, timeout, args, kwargs) + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + ctx: Union[zmq.Context, zmq.asyncio.Context], + output_path: str, + index: int = 0, + local_dp_rank: int = 0, + ): + # Paths and sockets for IPC. + input_path = get_open_zmq_ipc_path() + self.input_socket = make_zmq_socket(ctx, input_path, + zmq.constants.PUSH) + try: + # Start EngineCore in background process. + self.proc_handle = BackgroundProcHandle( + input_path=input_path, + output_path=output_path, + process_name=f"EngineCore_{index}", + target_fn=EngineCoreProc.run_engine_core, + process_kwargs={ + "vllm_config": vllm_config, + "dp_rank": index, + "local_dp_rank": local_dp_rank, + "executor_class": executor_class, + "log_stats": log_stats, + }) + + self.num_reqs_in_flight = 0 + finally: + if not hasattr(self, "num_reqs_in_flight"): + # Ensure socket is closed if process fails to start. 
+ self.close() + + def send_multipart(self, msg_parts: Sequence): + return self.input_socket.send_multipart(msg_parts, copy=False) + + def close(self): + if proc_handle := getattr(self, "proc_handle", None): + proc_handle.shutdown() + if socket := getattr(self, "input_socket", None): + socket.close(linger=0) + + +@dataclass +class BackgroundResources: + """Used as a finalizer for clean shutdown, avoiding + circular reference back to the client object.""" + + ctx: Union[zmq.Context] + core_engines: list[CoreEngine] = field(default_factory=list) + output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + shutdown_path: Optional[str] = None + + def __call__(self): + """Clean up background resources.""" + + for core_engine in self.core_engines: + core_engine.close() + + # ZMQ context termination can hang if the sockets + # aren't explicitly closed first. + if self.output_socket is not None: + self.output_socket.close(linger=0) + if self.shutdown_path is not None: + # We must ensure that the sync output socket is + # closed cleanly in its own thread. + with self.ctx.socket(zmq.PAIR) as shutdown_sender: + shutdown_sender.connect(self.shutdown_path) + # Send shutdown signal. + shutdown_sender.send(b'') + + +class MPClient(EngineCoreClient): + """ + MPClient: base client for multi-proc EngineCore. + EngineCore runs in a background process busy loop, getting + new EngineCoreRequests and returning EngineCoreOutputs + + * pushes EngineCoreRequests via input_socket + * pulls EngineCoreOutputs via output_socket + + * AsyncMPClient subclass for AsyncLLM usage + * SyncMPClient subclass for LLM usage + """ + + def __init__( + self, + asyncio_mode: bool, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + ): + # The child processes will send SIGUSR1 when unrecoverable + # errors happen. We kill the process tree here so that the + # stack trace is very evident. + # TODO(rob): rather than killing the main process, we should + # figure out how to raise an AsyncEngineDeadError and + # handle at the API server level so we can return a better + # error code to the clients calling vLLM. + def sigusr1_handler(signum, frame): + logger.fatal("Got fatal signal from worker processes, shutting " + "down. See stack trace above for root cause issue.") + kill_process_tree(os.getpid()) + + if threading.current_thread() == threading.main_thread(): + signal.signal(signal.SIGUSR1, sigusr1_handler) + else: + logger.warning("SIGUSR1 handler not installed because we are not " + "running in the main thread. In this case the " + "forked engine process may not be killed when " + "an exception is raised, and you need to handle " + "the engine process shutdown manually.") + + # Serialization setup. + self.encoder = MsgpackEncoder() + self.decoder = MsgpackDecoder(EngineCoreOutputs) + + # ZMQ setup. + sync_ctx = zmq.Context(io_threads=2) + self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx + + # This will ensure resources created so far are closed + # when the client is garbage collected, even if an + # exception is raised mid-construction. + self.resources = BackgroundResources(ctx=sync_ctx) + self._finalizer = weakref.finalize(self, self.resources) + + # Paths and sockets for IPC. + self.output_path = get_open_zmq_ipc_path() + + new_core_engine = lambda index, local_dp_rank=None: CoreEngine( + vllm_config, executor_class, log_stats, self.ctx, self.output_path, + index, local_dp_rank) + + # Start engine core process(es). 
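+        # _init_core_engines is overridden by DPAsyncMPClient to launch one
+        # engine per data-parallel rank.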
+ self._init_core_engines(vllm_config, new_core_engine, + self.resources.core_engines) + + # Wait for engine core process(es) to start. + for engine in self.resources.core_engines: + engine.proc_handle.wait_for_startup() + + self.utility_results: dict[int, AnyFuture] = {} + + def _init_core_engines( + self, + vllm_config: VllmConfig, + new_core_engine: Callable[[int, Optional[int]], CoreEngine], + core_engines: list[CoreEngine], + ) -> None: + + # Default case - single core engine. + dp_rank = vllm_config.parallel_config.data_parallel_rank + local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + core_engine = new_core_engine( + dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank) + core_engines.append(core_engine) + self.core_engine = core_engine + + def shutdown(self): + self._finalizer() + + +def _process_utility_output(output: UtilityOutput, + utility_results: dict[int, AnyFuture]): + """Set the result from a utility method in the waiting future""" + future = utility_results.pop(output.call_id) + if output.failure_message is not None: + future.set_exception(Exception(output.failure_message)) + else: + future.set_result(output.result) + + +class SyncMPClient(MPClient): + """Synchronous client for multi-proc EngineCore.""" + + def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], + log_stats: bool): + super().__init__( + asyncio_mode=False, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + ) + + self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + + # Ensure that the outputs socket processing thread does not have + # a ref to the client which prevents gc. + ctx = self.ctx + output_path = self.output_path + decoder = self.decoder + utility_results = self.utility_results + outputs_queue = self.outputs_queue + + shutdown_path = get_open_zmq_inproc_path() + self.resources.shutdown_path = shutdown_path + + def process_outputs_socket(): + shutdown_socket = ctx.socket(zmq.PAIR) + out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) + try: + shutdown_socket.bind(shutdown_path) + poller = zmq.Poller() + poller.register(shutdown_socket) + poller.register(out_socket) + while True: + socks = poller.poll() + if not socks: + continue + if len(socks) == 2 or socks[0][0] == shutdown_socket: + # shutdown signal, exit thread. + break + + frame = out_socket.recv(copy=False) + outputs = decoder.decode(frame.buffer) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + else: + outputs_queue.put_nowait(outputs) + finally: + # Close sockets. + shutdown_socket.close(linger=0) + out_socket.close(linger=0) + + # Process outputs from engine in separate thread. 
+ self.output_queue_thread = Thread(target=process_outputs_socket, + name="EngineCoreOutputQueueThread", + daemon=True) + self.output_queue_thread.start() + + def get_output(self) -> EngineCoreOutputs: + return self.outputs_queue.get() + + def _send_input(self, request_type: EngineCoreRequestType, request: Any): + # (RequestType, SerializedRequest) + msg = (request_type.value, self.encoder.encode(request)) + self.core_engine.send_multipart(msg) + + def call_utility(self, method: str, *args) -> Any: + call_id = uuid.uuid1().int >> 64 + future: Future[Any] = Future() + self.utility_results[call_id] = future + self._send_input(EngineCoreRequestType.UTILITY, + (call_id, method, args)) + + return future.result() + + def add_request(self, request: EngineCoreRequest) -> None: + # NOTE: text prompt is not needed in the core engine as it has been + # tokenized. + request.prompt = None + self._send_input(EngineCoreRequestType.ADD, request) + + def abort_requests(self, request_ids: list[str]) -> None: + if len(request_ids) > 0: + self._send_input(EngineCoreRequestType.ABORT, request_ids) + + def profile(self, is_start: bool = True) -> None: + self.call_utility("profile", is_start) + + def reset_prefix_cache(self) -> None: + self.call_utility("reset_prefix_cache") + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.call_utility("add_lora", lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.call_utility("remove_lora", lora_id) + + def list_loras(self) -> set[int]: + return self.call_utility("list_loras") + + def pin_lora(self, lora_id: int) -> bool: + return self.call_utility("pin_lora", lora_id) + + def sleep(self, level: int = 1) -> None: + self.call_utility("sleep", level) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.call_utility("wake_up", tags) + + def is_sleeping(self) -> bool: + return self.call_utility("is_sleeping") + + def execute_dummy_batch(self) -> None: + self.call_utility("execute_dummy_batch") + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.call_utility("collective_rpc", method, timeout, args, + kwargs) + + def save_sharded_state(self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None) -> None: + self.call_utility("save_sharded_state", path, pattern, max_size) + + +class AsyncMPClient(MPClient): + """Asyncio-compatible client for multi-proc EngineCore.""" + + def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], + log_stats: bool): + super().__init__( + asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + ) + + self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None + self.queue_task: Optional[asyncio.Task] = None + + self.outputs_handler: Optional[Callable[ + [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None + + def _ensure_output_queue_task(self): + if self.outputs_queue is not None: + return + + # Perform IO in separate task to parallelize as much as possible. + # Avoid task having direct reference back to the client. 
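+        # Everything the task needs is captured as local variables (plus a
+        # weakref) rather than via `self`, so the client can still be GC'd.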
+ self.outputs_queue = asyncio.Queue() + decoder = self.decoder + utility_results = self.utility_results + outputs_queue = self.outputs_queue + output_handler = self.outputs_handler + _self_ref = weakref.ref(self) if output_handler else None + output_path = self.output_path + output_socket = make_zmq_socket(self.ctx, output_path, + zmq.constants.PULL) + self.resources.output_socket = output_socket + + async def process_outputs_socket(): + while True: + (frame, ) = await output_socket.recv_multipart(copy=False) + outputs: EngineCoreOutputs = decoder.decode(frame.buffer) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + continue + + if output_handler is not None: + assert _self_ref is not None + _self = _self_ref() + if not _self: + # Client has been garbage collected, abort. + return + await output_handler(_self, outputs) + + if outputs.outputs or outputs.scheduler_stats: + outputs_queue.put_nowait(outputs) + + self.queue_task = asyncio.create_task(process_outputs_socket(), + name="EngineCoreOutputQueueTask") + + async def get_output_async(self) -> EngineCoreOutputs: + self._ensure_output_queue_task() + assert self.outputs_queue is not None + return await self.outputs_queue.get() + + async def _send_input(self, request_type: EngineCoreRequestType, + request: Any) -> None: + await self.core_engine.send_multipart( + (request_type.value, self.encoder.encode(request))) + + self._ensure_output_queue_task() + + async def call_utility_async(self, method: str, *args) -> Any: + return await self._call_utility_async(method, + *args, + engine=self.core_engine) + + async def _call_utility_async( + self, + method: str, + *args, + engine: CoreEngine, + ) -> Any: + call_id = uuid.uuid1().int >> 64 + future = asyncio.get_running_loop().create_future() + self.utility_results[call_id] = future + message = (EngineCoreRequestType.UTILITY.value, + self.encoder.encode((call_id, method, args))) + await engine.send_multipart(message) + self._ensure_output_queue_task() + return await future + + async def add_request_async(self, request: EngineCoreRequest) -> None: + # NOTE: text prompt is not needed in the core engine as it has been + # tokenized. 
+ request.prompt = None + await self._send_input(EngineCoreRequestType.ADD, request) + + async def abort_requests_async(self, request_ids: list[str]) -> None: + if len(request_ids) > 0: + await self._send_input(EngineCoreRequestType.ABORT, request_ids) + + async def profile_async(self, is_start: bool = True) -> None: + await self.call_utility_async("profile", is_start) + + async def reset_prefix_cache_async(self) -> None: + await self.call_utility_async("reset_prefix_cache") + + async def sleep_async(self, level: int = 1) -> None: + await self.call_utility_async("sleep", level) + + async def wake_up_async(self, tags: Optional[list[str]] = None) -> None: + await self.call_utility_async("wake_up", tags) + + async def is_sleeping_async(self) -> bool: + return await self.call_utility_async("is_sleeping") + + async def execute_dummy_batch_async(self) -> None: + await self.call_utility_async("execute_dummy_batch") + + async def add_lora_async(self, lora_request: LoRARequest) -> bool: + return await self.call_utility_async("add_lora", lora_request) + + async def remove_lora_async(self, lora_id: int) -> bool: + return await self.call_utility_async("remove_lora", lora_id) + + async def list_loras_async(self) -> set[int]: + return await self.call_utility_async("list_loras") + + async def pin_lora_async(self, lora_id: int) -> bool: + return await self.call_utility_async("pin_lora", lora_id) + + async def save_sharded_state_async(self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None) -> None: + await self.call_utility_async("save_sharded_state", path, pattern, + max_size) + + async def collective_rpc_async( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return await self.call_utility_async("collective_rpc", method, timeout, + args, kwargs) + + +class DPAsyncMPClient(AsyncMPClient): + """Asyncio-compatible client for multi-proc, multi-engine (data parallel) + EngineCore.""" + + def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], + log_stats: bool): + super().__init__(vllm_config, executor_class, log_stats) + + assert len(self.core_engines) > 1 + + # Control message used for triggering dp idle mode loop. + self.start_dp_msg = (EngineCoreRequestType.START_DP.value, + self.encoder.encode(None)) + + self.num_engines_running = 0 + self.reqs_in_flight: dict[str, CoreEngine] = {} + + self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] + + def _init_core_engines( + self, + vllm_config: VllmConfig, + new_core_engine: Callable[[int, Optional[int]], CoreEngine], + core_engines: list[CoreEngine], + ) -> None: + + # Launch a core engine for each data parallel rank. + dp_size = vllm_config.parallel_config.data_parallel_size + for i in range(dp_size): + # Multi-node not yet supported so local_dp_rank == dp_rank. + core_engines.append(new_core_engine(i, i)) + + self.core_engines = core_engines + + async def call_utility_async(self, method: str, *args) -> Any: + # Only the result from the first engine is returned. + return (await asyncio.gather(*[ + self._call_utility_async(method, *args, engine=engine) + for engine in self.core_engines + ]))[0] + + async def add_request_async(self, request: EngineCoreRequest) -> None: + # NOTE: text prompt is not needed in the core engine as it has been + # tokenized. 
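+        # (The request itself is routed to the engine with the fewest
+        # in-flight requests; see get_core_engine_for_request below.)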
+ request.prompt = None + + msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request)) + + chosen_engine = self.get_core_engine_for_request() + self.reqs_in_flight[request.request_id] = chosen_engine + chosen_engine.num_reqs_in_flight += 1 + if self.num_engines_running >= len(self.core_engines): + await chosen_engine.send_multipart(msg) + else: + # Send request to chosen engine and dp start loop + # control message to all other engines. + self.num_engines_running += len(self.core_engines) + await asyncio.gather(*[ + engine.send_multipart(msg if engine is + chosen_engine else self.start_dp_msg) + for engine in self.core_engines + ]) + + self._ensure_output_queue_task() + + def get_core_engine_for_request(self) -> CoreEngine: + return min(self.core_engines, key=lambda e: e.num_reqs_in_flight) + + @staticmethod + async def process_engine_outputs(self: "DPAsyncMPClient", + outputs: EngineCoreOutputs): + if self.reqs_in_flight: + for req_id in outputs.finished_requests or (): + if engine := self.reqs_in_flight.pop(req_id, None): + engine.num_reqs_in_flight -= 1 + + if outputs.engine_paused: + assert self.num_engines_running >= 1 + self.num_engines_running -= 1 + if not self.num_engines_running and self.reqs_in_flight: + # If there are requests in flight here, they must have + # been sent after the engines paused. We must make + # sure to start the other engines: + self.num_engines_running = len(self.core_engines) + coros = [ + engine.send_multipart(self.start_dp_msg) + for engine in self.core_engines + if not engine.num_reqs_in_flight + ] + if coros: + await asyncio.gather(*coros) + + async def abort_requests_async(self, request_ids: list[str]) -> None: + if not request_ids: + return + + if len(request_ids) == 1: + # Fast-path common case. + if engine := self.reqs_in_flight.get(request_ids[0]): + await self._abort_requests(request_ids, engine) + return + + by_engine: dict[CoreEngine, list[str]] = {} + for req_id in request_ids: + if engine := self.reqs_in_flight.get(req_id): + by_engine.setdefault(engine, []).append(req_id) + for engine, req_ids in by_engine.items(): + await self._abort_requests(req_ids, engine) + + async def _abort_requests(self, request_ids: list[str], + engine: CoreEngine) -> None: + await engine.send_multipart((EngineCoreRequestType.ABORT.value, + self.encoder.encode(request_ids))) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py new file mode 100644 index 0000000..bf06a17 --- /dev/null +++ b/vllm/v1/engine/detokenizer.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import Optional + +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.transformers_utils.detokenizer_utils import ( + AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) +from vllm.v1.engine import EngineCoreRequest + +logger = init_logger(__name__) + + +@dataclass +class IncrementalDetokenizer: + + # Generation data + token_ids: list[int] + output_text: str = "" + tokens: list[str] = field(default_factory=list) + prompt_len: int = 0 + + # Stop strings + stop: list[str] = field(default_factory=list) + include_stop_str_in_output: bool = False + + # Metadata for incremental detokenization + prefix_offset: int = 0 + read_offset: int = 0 + + # Parameters for detokenization + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + + # Tokenizer for this request, + # None if detokenization is disabled. 
+ tokenizer: Optional[AnyTokenizer] = None + + # Accounting for stop string buffering + stop_buffer_length: int = 0 + _last_output_text_offset: int = 0 + + @property + def output_token_ids(self) -> list[int]: + return self.token_ids if not self.prompt_len else ( + self.token_ids[self.prompt_len:]) + + @classmethod + def from_new_request( + cls, + tokenizer: Optional[AnyTokenizer], + request: EngineCoreRequest, + ) -> "IncrementalDetokenizer": + + if tokenizer is None: + return cls(token_ids=[]) + + tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=request.prompt_token_ids, + skip_special_tokens=request.sampling_params.skip_special_tokens, + ) + + stops = request.sampling_params.stop + # Number of chars to hold back when stop strings are to be excluded + # from streamed output. + if stops and not request.sampling_params.include_stop_str_in_output: + stop_buffer_length = max(len(s) for s in stops) - 1 + else: + stop_buffer_length = 0 + + return cls( + tokens=tokens, + # Detokenizer mutates this list, so need a unique copy. + # NOTE(Nick): could we take ownership of it though? + token_ids=request.prompt_token_ids.copy(), + stop=stops, + include_stop_str_in_output=request.sampling_params. + include_stop_str_in_output, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=request.sampling_params.skip_special_tokens, + spaces_between_special_tokens=request.sampling_params. + spaces_between_special_tokens, + prompt_len=len(request.prompt_token_ids), + tokenizer=tokenizer, + stop_buffer_length=stop_buffer_length, + ) + + def update(self, new_token_ids: list[int], + stop_terminated: bool) -> Optional[str]: + """ + Update RequestState for the request_id by: + 1) Detokenize the new token ids incrementally. + 2) Evaluate stop criteria. + + Return matched stop string or None. + """ + if not new_token_ids: + # Skip detokenization if no new token ids + return None + if self.tokenizer is None: + # Skip detokenization if no tokenizer + self.token_ids.extend(new_token_ids) + return None + + if stop_terminated and not self.include_stop_str_in_output: + # If stop-terminated, exclude last token from detokenization + # based on include_stop_str_in_output parameter. + skipped_stop_token_id = new_token_ids[-1] + new_token_ids = new_token_ids[:-1] + else: + skipped_stop_token_id = None + + # 1) Detokenize the new token ids incrementally. + # TODO(woosuk): This method becomes very inefficient when the number of + # new_token_ids is more than 1. We need to optimize this. + decoded_text = "" + for new_token_id in new_token_ids: + self.token_ids.append(new_token_id) + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self. + spaces_between_special_tokens, + ) + + self.tokens.extend(new_tokens) + self.prefix_offset = prefix_offset + self.read_offset = read_offset + + decoded_text += new_decoded_token_text + + self.output_text += decoded_text + + if stop_terminated: + if skipped_stop_token_id is not None: + # Cleanup after skipping detokenization + self.token_ids.append(skipped_stop_token_id) + # Stop token triggered; skip stop string check + return None + + # 2) Evaluate stop strings. 
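+        # new_char_count tells StopChecker how many characters were just
+        # added, so only text that could newly contain a stop string is
+        # checked.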
+ stop_string = None + if self.stop: + stop = StopChecker.check_stop_strings( + output_text=self.output_text, + new_char_count=len(decoded_text), + stop=self.stop, + include_in_output=self.include_stop_str_in_output, + ) + if stop is not None: + stop_string, truncate_to = stop + if truncate_to != -1: + self.output_text = self.output_text[:truncate_to] + + return stop_string + + def get_next_output_text(self, finished: bool, delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + + # We return the full output text if the sequence is finished. + buffer_length = 0 if finished else self.stop_buffer_length + if not delta: + return self.output_text[:-buffer_length] if buffer_length else ( + self.output_text) + length = len(self.output_text) - buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py new file mode 100644 index 0000000..4c67186 --- /dev/null +++ b/vllm/v1/engine/llm_engine.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Mapping +from copy import copy +from typing import Any, Callable, Optional, Union + +from typing_extensions import TypeVar + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import stateless_destroy_torch_distributed_process_group +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics_types import StatLoggerBase +from vllm.inputs import PromptType +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.outputs import RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Device +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.engine.parallel_sampling import ParentRequest +from vllm.v1.engine.processor import Processor +from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) + +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) +_R = TypeVar("_R", default=Any) + + +class LLMEngine: + """Legacy LLMEngine for backwards compatibility.""" + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + multiprocess_mode: bool = False, + ) -> None: + if not envs.VLLM_USE_V1: + raise ValueError( + "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. " + "This should not happen. As a workaround, try using " + "LLMEngine.from_vllm_config(...) 
or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + + # important: init dp group before init the engine_core + # In the decoupled engine case this is handled in EngineCoreProc. + parallel_config = vllm_config.parallel_config + if not multiprocess_mode and parallel_config.data_parallel_size > 1: + self.dp_group = parallel_config.stateless_init_dp_group() + else: + self.dp_group = None + self.should_execute_dummy_batch = False + + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) + self.tokenizer.ping() + + # Processor (convert Inputs --> EngineCoreRequests) + self.processor = Processor(vllm_config=vllm_config, + tokenizer=self.tokenizer, + mm_registry=mm_registry) + + # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). + self.output_processor = OutputProcessor(self.tokenizer, + log_stats=False) + + # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) + self.engine_core = EngineCoreClient.make_client( + multiprocess_mode=multiprocess_mode, + asyncio_mode=False, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, # FIXME: implement + ) + + if not multiprocess_mode: + # for v0 compatibility + self.model_executor = self.engine_core.engine_core.model_executor # type: ignore + + @classmethod + def from_vllm_config( + cls, + vllm_config: VllmConfig, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + disable_log_stats: bool = False, + ) -> "LLMEngine": + if stat_loggers is not None: + raise NotImplementedError( + "Passing StatLoggers to V1 is not yet supported. " + "Set VLLM_USE_V1=0 and file and issue on Github.") + + return cls(vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=(not disable_log_stats), + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING) + + @classmethod + def from_engine_args( + cls, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + enable_multiprocessing: bool = False, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + + # Create the engine configs. + vllm_config = engine_args.create_engine_config(usage_context) + executor_class = Executor.get_class(vllm_config) + + if envs.VLLM_ENABLE_V1_MULTIPROCESSING: + logger.debug("Enabling multiprocessing for LLMEngine.") + enable_multiprocessing = True + + # Create the LLMEngine. 
+ return cls(vllm_config=vllm_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=enable_multiprocessing) + + def get_num_unfinished_requests(self) -> int: + return self.output_processor.get_num_unfinished_requests() + + def has_unfinished_requests(self) -> bool: + has_unfinished = self.output_processor.has_unfinished_requests() + if self.dp_group is None: + return has_unfinished + return self.has_unfinished_requests_dp(has_unfinished) + + def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: + aggregated_has_unfinished = ParallelConfig.has_unfinished_dp( + self.dp_group, has_unfinished) + if not has_unfinished and aggregated_has_unfinished: + self.should_execute_dummy_batch = True + return aggregated_has_unfinished + + @classmethod + def validate_outputs(cls, outputs, output_type): + return outputs + + def abort_request(self, request_ids: list[str]) -> None: + """Remove request_ids from EngineCore and Detokenizer.""" + + request_ids = self.output_processor.abort_requests(request_ids) + self.engine_core.abort_requests(request_ids) + + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + # Process raw inputs into the request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) + + n = params.n if isinstance(params, SamplingParams) else 1 + + if n == 1: + # Make a new RequestState and queue. + self.output_processor.add_request(request, None, 0) + # Add the request to EngineCore. + self.engine_core.add_request(request) + return + + # Fan out child requests (for n>1). + parent_req = ParentRequest(request_id, params) + for idx in range(n): + request_id, params = parent_req.get_child_info(idx) + child_request = request if idx == n - 1 else copy(request) + child_request.request_id = request_id + child_request.sampling_params = params + + # Make a new RequestState and queue. + self.output_processor.add_request(child_request, parent_req, idx) + # Add the request to EngineCore. + self.engine_core.add_request(child_request) + + def step(self) -> list[RequestOutput]: + + if self.should_execute_dummy_batch: + self.should_execute_dummy_batch = False + self.engine_core.execute_dummy_batch() + return [] + + # 1) Get EngineCoreOutput from the EngineCore. + outputs = self.engine_core.get_output() + + # 2) Process EngineCoreOutputs. + processed_outputs = self.output_processor.process_outputs( + outputs.outputs) + + # 3) Abort any reqs that finished due to stop strings. 
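+        # Stop strings are detected client-side during detokenization, so the
+        # core must be told explicitly to drop those requests.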
+ self.engine_core.abort_requests(processed_outputs.reqs_to_abort) + + return processed_outputs.request_outputs + + def get_model_config(self): + return self.model_config + + def start_profile(self): + self.engine_core.profile(True) + + def stop_profile(self): + self.engine_core.profile(False) + + def reset_prefix_cache(self, device: Optional[Device] = None): + self.engine_core.reset_prefix_cache() + + def sleep(self, level: int = 1): + self.engine_core.sleep(level) + + def wake_up(self, tags: Optional[list[str]] = None): + self.engine_core.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.engine_core.is_sleeping() + + def get_tokenizer_group( + self, + group_type: type[_G] = BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group + + def add_lora(self, lora_request: LoRARequest) -> bool: + """Load a new LoRA adapter into the engine for future requests.""" + return self.engine_core.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + """Remove an already loaded LoRA adapter.""" + return self.engine_core.remove_lora(lora_id) + + def list_loras(self) -> set[int]: + """List all registered adapters.""" + return self.engine_core.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + return self.engine_core.pin_lora(lora_id) + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.engine_core.collective_rpc(method, timeout, args, kwargs) + + def __del__(self): + if dp_group := getattr(self, "dp_group", None): + stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py new file mode 100644 index 0000000..03d82b6 --- /dev/null +++ b/vllm/v1/engine/logprobs.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Optional + +from vllm.logger import init_logger +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.transformers_utils.detokenizer_utils import ( + AnyTokenizer, convert_ids_list_to_tokens) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +logger = init_logger(__name__) + +NONES = itertools.repeat(None) + + +@dataclass +class LogprobsProcessor: + + # Tokenizer for this request, + # None if detokenization is disabled. 
+ tokenizer: Optional[AnyTokenizer] + + # Logprobs for this request + logprobs: Optional[SampleLogprobs] + prompt_logprobs: Optional[PromptLogprobs] + cumulative_logprob: Optional[float] + num_logprobs: Optional[int] + num_prompt_logprobs: Optional[int] + + @classmethod + def from_new_request( + cls, + tokenizer: Optional[AnyTokenizer], + request: EngineCoreRequest, + ) -> "LogprobsProcessor": + num_logprobs = request.sampling_params.logprobs + num_prompt_logprobs = request.sampling_params.prompt_logprobs + return cls( + tokenizer=tokenizer, + cumulative_logprob=(None if num_logprobs is None else 0.), + logprobs=(None if num_logprobs is None else []), + # NOTE: logprob of first prompt token is None. + prompt_logprobs=(None if num_prompt_logprobs is None else [None]), + num_prompt_logprobs=num_prompt_logprobs, + num_logprobs=num_logprobs, + ) + + def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: + """Update with sample logprobs from EngineCore. + + Outer lists are only of len > 1 if EngineCore made + >1 tokens in prior step (e.g. in spec decoding). + + Args: + logprobs_lists: the lists of logprob tokens, logprobs, and ranks. + + """ + + assert self.num_logprobs is not None + assert self.logprobs is not None + assert self.cumulative_logprob is not None + + token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists + + for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, + token_ids_lst): + + # Detokenize (non-incrementally). + decoded_tokens = NONES if self.tokenizer is None else ( + convert_ids_list_to_tokens(self.tokenizer, token_ids)) + + # Sampler puts the sampled logprob in first. + sampled_token_logprob = logprobs[0] + self.cumulative_logprob += sampled_token_logprob + + # Update with the Logprob dictionary for this pos. + self.logprobs.append( + self._make_logprob_dict( + logprobs, + token_ids, + decoded_tokens, + rank, + self.num_logprobs, + )) + + def _update_prompt_logprobs( + self, + prompt_logprobs_tensors: LogprobsTensors, + ) -> None: + """Update with prompt logprobs from EngineCore. + + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + + """ + + # Prompt logprobs are enabled. + assert self.num_prompt_logprobs is not None + assert self.prompt_logprobs is not None + + token_ids, logprobs, ranks = prompt_logprobs_tensors + + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + decoded_tokens = None if self.tokenizer is None else ( + convert_ids_list_to_tokens(self.tokenizer, + token_ids.flatten().tolist())) + + # Recover shapes. + num_prompt_tokens, num_logprobs = logprobs.shape + + # Pythonize the torch tensors. + prompt_token_ranks = ranks.tolist() + prompt_logprobs = logprobs.tolist() + token_ids = token_ids.tolist() + + # Make Logprob for each position. + for pos in range(num_prompt_tokens): + # Handle flattening. + offset = pos * num_logprobs + offset_end = offset + num_logprobs + decoded_tokens_for_pos = NONES \ + if decoded_tokens is None else decoded_tokens[offset:offset_end] + + # Update with the Logprob dictionary for this pos. + self.prompt_logprobs.append( + self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + self.num_prompt_logprobs)) + + def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: + """Pop and return all request prompt logprobs + + The logprobs processor aggregates prompt chunk logprobs + over one or more prefill chunks. 
This method returns + all prompt logprobs at once and then forgets them. + Ensures correct RequestOutputKind.DELTA semantics + wherein all prompt logprobs are returned at once at + the end of prefill. + + Returns: + None if prompt logprobs are disabled for this request. + List of all prompt logprobs, otherwise. + """ + plp = self.prompt_logprobs + if plp: + self.prompt_logprobs = [] + return plp + + @staticmethod + def _make_logprob_dict( + logprobs: list[float], + logprob_token_ids: list[int], + decoded_tokens: Iterable[Optional[str]], + rank: int, + num_logprobs: int, + ) -> dict[int, Logprob]: + """Make a Logprob dictionary for a position. + + Args: + logprobs: list of log probabilities + logprob_token_ids: list of top token ids + decoded_tokens: list of decoded top tokens + rank: rank of the sampled token + num_logprobs: number of logprobs requested + by the user (in addition to sampled logprob) + + Returns: + dict[token id, Logprob] + """ + + # We do not need a special case for the sampled token + # being in the topk, since inserting duplicated data + # into a dictionary twice is the same as doing it once. + topk_ranks = range(1, num_logprobs + 1) + ranks = itertools.chain((rank, ), topk_ranks) + + return { + token_id: Logprob( + logprob=logprob, + rank=rank, + decoded_token=token, + ) + for token_id, logprob, rank, token in zip( + logprob_token_ids, logprobs, ranks, decoded_tokens) + } + + def update_from_output(self, output: EngineCoreOutput) -> None: + if output.new_logprobs is not None: + self._update_sample_logprobs(output.new_logprobs) + if output.new_prompt_logprobs_tensors is not None: + self._update_prompt_logprobs(output.new_prompt_logprobs_tensors) diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py new file mode 100644 index 0000000..61a55d2 --- /dev/null +++ b/vllm/v1/engine/mm_input_cache.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.envs import VLLM_MM_INPUT_CACHE_GIB +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.processing import ProcessingCache + +# The idea of multimodal preprocessing caching is based on having a client and +# a server, where the client executes in the frontend process (=P0) and the +# server in the core process (=P1). +# +# -- Client: +# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs +# with built-in caching functionality, with mm_hash as its identifier. +# +# -- Server: +# - MMInputCacheServer to perform caching of the received MultiModalKwargs. +# +# The caching for both client and server is mirrored, and this allows us +# to avoid the serialization of "mm_inputs" (like pixel values) between +# client (=P0) and server (=P1) processes if the mm_hash is found in the client +# cache. + +# Both Client and Server must use the same cache size +# (to perform mirrored caching). This cache size is set by the environment +# variable VLLM_MM_INPUT_CACHE_GIB. 
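The mirrored-cache protocol described in the comment above can be illustrated with a toy LRU keyed by `mm_hash`. Everything below is a simplified sketch: `MirroredLRU`, `client_send`, and `server_receive` are hypothetical names, and capacity is counted in entries rather than GiB purely for illustration.

```python
from collections import OrderedDict
from typing import Any, Optional


class MirroredLRU:
    """Toy LRU cache; both client (P0) and server (P1) hold one instance."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data: "OrderedDict[str, Any]" = OrderedDict()

    def __contains__(self, key: str) -> bool:
        return key in self.data

    def put(self, key: str, value: Any) -> None:
        self.data[key] = value
        self.data.move_to_end(key)
        if len(self.data) > self.capacity:
            self.data.popitem(last=False)  # evict least-recently-used entry

    def get(self, key: str) -> Any:
        self.data.move_to_end(key)
        return self.data[key]


def client_send(cache: MirroredLRU, mm_hash: str, mm_kwargs: Any) -> Optional[Any]:
    """P0 side: on a hit, send None so the heavy kwargs are never serialized."""
    if mm_hash in cache:
        cache.get(mm_hash)  # refresh recency so both sides evict identically
        return None
    cache.put(mm_hash, mm_kwargs)
    return mm_kwargs


def server_receive(cache: MirroredLRU, mm_hash: str, payload: Optional[Any]) -> Any:
    """P1 side: mirrors MMInputCacheServer.get_and_update for a single item."""
    if payload is None:
        return cache.get(mm_hash)
    cache.put(mm_hash, payload)
    return payload
```

The scheme only works because both sides make identical insert/evict decisions in the same order, which is why the comment above insists that client and server use the same cache size.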
+ + +class MMInputCacheServer: + + def __init__(self, model_config): + self.use_cache = not model_config.disable_mm_preprocessor_cache + self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, + MultiModalKwargs) + + def get_and_update( + self, + mm_inputs: list[MultiModalKwargs], + mm_hashes: list[str], + ) -> list[MultiModalKwargs]: + assert len(mm_inputs) == len(mm_hashes) + + if not self.use_cache: + return mm_inputs + + full_mm_inputs = [] + for mm_input, mm_hash in zip(mm_inputs, mm_hashes): + assert mm_hash is not None + if mm_input is None: + mm_input = self.mm_cache[mm_hash] + else: + self.mm_cache[mm_hash] = mm_input + + full_mm_inputs.append(mm_input) + + return full_mm_inputs diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py new file mode 100644 index 0000000..70f072d --- /dev/null +++ b/vllm/v1/engine/output_processor.py @@ -0,0 +1,405 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Optional, Union + +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import RequestOutputKind +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine.detokenizer import IncrementalDetokenizer +from vllm.v1.engine.logprobs import LogprobsProcessor +from vllm.v1.engine.parallel_sampling import ParentRequest +from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, + RequestStateStats) + + +class RequestOutputCollector: + """ + Collects streamed RequestOutputs per individual request, + for hand-off to the consuming asyncio generate task. + + When streaming deltas, RequestOutputs are merged if the + producer gets ahead of the consumer. + """ + + def __init__(self, output_kind: RequestOutputKind): + self.aggregate = output_kind == RequestOutputKind.DELTA + self.output: Optional[RequestOutput] = None + self.ready = asyncio.Event() + + def put(self, output: RequestOutput) -> None: + if self.output is None: + self.output = output + self.ready.set() + elif self.aggregate: + # Coalesce the outputs in delta case. + self.output.add(output) + else: + # Just replace latest in non-delta case. 
+ self.output = output + + async def get(self) -> RequestOutput: + while (output := self.output) is None: + await self.ready.wait() + self.output = None + self.ready.clear() + return output + + def get_nowait(self) -> Optional[RequestOutput]: + output = self.output + if output is not None: + self.output = None + self.ready.clear() + return output + + +@dataclass +class OutputProcessorOutput: + + request_outputs: list[RequestOutput] + reqs_to_abort: list[str] + + +class RequestState: + + def __init__( + self, + request_id: str, + parent_req: Optional[ParentRequest], + request_index: int, + lora_name: Optional[str], + output_kind: RequestOutputKind, + prompt: Optional[str], + prompt_token_ids: list[int], + logprobs_processor: LogprobsProcessor, + detokenizer: IncrementalDetokenizer, + max_tokens_param: Optional[int], + arrival_time: float, + queue: Optional[RequestOutputCollector], + log_stats: bool, + ): + self.request_id = request_id + self.parent_req = parent_req + self.request_index = request_index + self.lora_name = lora_name + self.output_kind = output_kind + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.prompt_len = len(prompt_token_ids) + self.logprobs_processor = logprobs_processor + self.detokenizer = detokenizer + self.max_tokens_param = max_tokens_param + self.is_prefilling = True + self.queue = queue + + self.stats = RequestStateStats( + arrival_time=arrival_time) if log_stats else None + + @classmethod + def from_new_request( + cls, + tokenizer: AnyTokenizer, + request: EngineCoreRequest, + parent_req: Optional[ParentRequest], + request_index: int, + queue: Optional[RequestOutputCollector], + log_stats: bool, + ) -> "RequestState": + if not request.sampling_params.detokenize: + tokenizer = None + return cls( + request_id=request.request_id, + parent_req=parent_req, + request_index=request_index, + lora_name=(request.lora_request.name + if request.lora_request is not None else None), + output_kind=request.sampling_params.output_kind, + prompt=request.prompt, + prompt_token_ids=request.prompt_token_ids, + logprobs_processor=LogprobsProcessor.from_new_request( + tokenizer=tokenizer, + request=request, + ), + detokenizer=IncrementalDetokenizer.from_new_request( + tokenizer=tokenizer, + request=request, + ), + max_tokens_param=(request.sampling_params.max_tokens if + request.sampling_params is not None else None), + arrival_time=request.arrival_time, + queue=queue, + log_stats=log_stats, + ) + + def make_request_output( + self, + new_token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> Optional[RequestOutput]: + + finished = finish_reason is not None + final_only = self.output_kind == RequestOutputKind.FINAL_ONLY + + if not finished and final_only: + # Only the final output is required in FINAL_ONLY mode. 
+ return None + + completion_output = self._new_completion_output( + new_token_ids, finish_reason, stop_reason) + + request_id = self.request_id + if self.parent_req is None: + outputs = [completion_output] + else: + request_id, outputs, finished = self.parent_req.get_outputs( + request_id, completion_output) + if not outputs: + return None + + return self._new_request_output(request_id, outputs, finished) + + def _new_request_output( + self, + request_id: str, + outputs: list[CompletionOutput], + finished: bool, + ) -> RequestOutput: + + if self.output_kind == RequestOutputKind.DELTA: + # Side effect: logprobs processor forgets prompt logprobs + prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs() + else: + prompt_logprobs = self.logprobs_processor.prompt_logprobs + + return RequestOutput( + request_id=request_id, + prompt=self.prompt, + prompt_token_ids=self.prompt_token_ids, + prompt_logprobs=prompt_logprobs, + outputs=outputs, + finished=finished, + ) + + def _new_completion_output( + self, + token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> CompletionOutput: + + finished = finish_reason is not None + delta = self.output_kind == RequestOutputKind.DELTA + + # Prepare text and token_ids, based on delta mode + text = self.detokenizer.get_next_output_text(finished, delta) + if not delta: + token_ids = self.detokenizer.output_token_ids + + # Prepare logprobs, based on delta mode + logprobs = self.logprobs_processor.logprobs + if delta and logprobs: + logprobs = logprobs[-len(token_ids):] + + return CompletionOutput( + index=self.request_index, + text=text, + token_ids=token_ids, + logprobs=logprobs, + cumulative_logprob=self.logprobs_processor.cumulative_logprob, + finish_reason=str(finish_reason) if finished else None, + stop_reason=stop_reason if finished else None) + + +class OutputProcessor: + """Process EngineCoreOutputs into RequestOutputs.""" + + def __init__( + self, + tokenizer: BaseTokenizerGroup, + log_stats: bool, + ): + self.log_stats = log_stats + self.tokenizer = tokenizer + self.request_states: dict[str, RequestState] = {} + self.parent_requests: dict[str, ParentRequest] = {} + self.lora_states = LoRARequestStates() + + def get_num_unfinished_requests(self): + return len(self.request_states) + + def has_unfinished_requests(self) -> bool: + return len(self.request_states) > 0 + + def abort_requests( + self, + request_ids: Iterable[str], + ) -> list[str]: + request_ids_to_abort = [] + for request_id in request_ids: + req_state = self.request_states.pop(request_id, None) + if req_state is not None: + self.lora_states.abort_request(req_state) + request_ids_to_abort.append(request_id) + else: + parent = self.parent_requests.pop(request_id, None) + if parent and parent.child_requests: + self.abort_requests(parent.child_requests) + request_ids_to_abort.extend(parent.child_requests) + return request_ids_to_abort + + def add_request( + self, + request: EngineCoreRequest, + parent_req: Optional[ParentRequest] = None, + request_index: int = 0, + queue: Optional[RequestOutputCollector] = None, + ) -> None: + request_id = request.request_id + if request_id in self.request_states: + raise ValueError(f"Request id {request_id} already running.") + + req_state = RequestState.from_new_request( + tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), + request=request, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) + self.request_states[request_id] = req_state + 
self.lora_states.add_request(req_state) + if parent_req: + self.parent_requests[parent_req.request_id] = parent_req + + def process_outputs( + self, + engine_core_outputs: list[EngineCoreOutput], + engine_core_timestamp: Optional[float] = None, + iteration_stats: Optional[IterationStats] = None, + ) -> OutputProcessorOutput: + """ + Process the EngineCoreOutputs: + 1) Compute stats for logging + 2) Detokenize + 3) Create and handle RequestOutput objects: + * If there is a queue (for usage with AsyncLLM), + put the RequestOutput objects into the queue for + handling by the per-request generate() tasks. + + * If there is no queue (for usage with LLMEngine), + return a list of RequestOutput objects. + + ****************** NOTE FOR DEVELOPERS ****************** + + vLLM V1 minimizes the number of python loops over the full + batch to ensure system overheads are minimized. This is the + only function that should loop over EngineCoreOutputs. + + If you need to touch every element of the batch, do it from + within the loop below. + + ********************************************************** + """ + + request_outputs: list[RequestOutput] = [] + reqs_to_abort: list[str] = [] + for engine_core_output in engine_core_outputs: + req_id = engine_core_output.request_id + req_state = self.request_states.get(req_id) + if req_state is None: + # Ignore output for already-aborted request. + continue + + # 1) Compute stats for this iteration. + self._update_stats_from_output(req_state, engine_core_output, + engine_core_timestamp, + iteration_stats) + + new_token_ids = engine_core_output.new_token_ids + finish_reason = engine_core_output.finish_reason + stop_reason = engine_core_output.stop_reason + + req_state.is_prefilling = False + + # 2) Detokenize the token ids into text and perform stop checks. + stop_string = req_state.detokenizer.update( + new_token_ids, finish_reason == FinishReason.STOP) + if stop_string: + finish_reason = FinishReason.STOP + stop_reason = stop_string + + # 3) Compute sample and prompt logprobs for request, if required. + req_state.logprobs_processor.update_from_output(engine_core_output) + + # 4) Create and handle RequestOutput objects. + if request_output := req_state.make_request_output( + new_token_ids, finish_reason, stop_reason): + if req_state.queue is not None: + # AsyncLLM: put into queue for handling by generate(). + req_state.queue.put(request_output) + else: + # LLMEngine: return list of RequestOutputs. + request_outputs.append(request_output) + + # Free completed requests. + if finish_reason is not None: + self.request_states.pop(req_id) + # Remove parent request if applicable. + parent_req = req_state.parent_req + if parent_req and not parent_req.child_requests: + self.parent_requests.pop(parent_req.request_id, None) + if not engine_core_output.finished: + # If req not finished in EngineCore, but Detokenizer + # detected stop string, abort needed in EngineCore. 
+ reqs_to_abort.append(req_id) + + # Track per-request stats + self._update_stats_from_finished(req_state, finish_reason, + iteration_stats) + + self.lora_states.update_iteration_stats(iteration_stats) + + return OutputProcessorOutput( + request_outputs=request_outputs, + reqs_to_abort=reqs_to_abort, + ) + + def _update_stats_from_output(self, req_state: RequestState, + engine_core_output: EngineCoreOutput, + engine_core_timestamp: Optional[float], + iteration_stats: Optional[IterationStats]): + if iteration_stats is None: + return + + lora_stats = self.lora_states.get_stats(req_state) + + assert engine_core_timestamp is not None + assert req_state.stats is not None + iteration_stats.update_from_output(engine_core_output, + engine_core_timestamp, + req_state.is_prefilling, + req_state.prompt_len, + req_state.stats, lora_stats) + + def _update_stats_from_finished(self, req_state: RequestState, + finish_reason: Optional[FinishReason], + iteration_stats: Optional[IterationStats]): + if iteration_stats is None: + return + + assert finish_reason is not None + assert req_state.stats is not None + iteration_stats.update_from_finished_request( + finish_reason=finish_reason, + num_prompt_tokens=len(req_state.prompt_token_ids), + max_tokens_param=req_state.max_tokens_param, + req_stats=req_state.stats) + self.lora_states.finish_request(req_state) + + ParentRequest.observe_finished_request( + req_state.parent_req, iteration_stats, + req_state.stats.num_generation_tokens) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py new file mode 100644 index 0000000..4df7ca5 --- /dev/null +++ b/vllm/v1/engine/parallel_sampling.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 + +from copy import copy +from typing import Optional + +from vllm.outputs import CompletionOutput +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.v1.metrics.stats import IterationStats + + +class ParentRequest: + """Info, state & processing for parallel sampling request. + + Store parent request ID and sampling params. + Facilitate generating child request sampling params. + """ + + request_id: str + sampling_params: SamplingParams + + # To track the completion of child requests + child_requests: set[str] + + # To aggregate child completions when not streaming + output_aggregator: list[CompletionOutput] + + # To find the max number of generated tokens across all children + max_num_generation_tokens: int + + # To efficiently obtain child sampling params + cached_child_sampling_params: Optional[SamplingParams] + + def __init__(self, request_id: str, + sampling_params: SamplingParams) -> None: + self.request_id = request_id + self.sampling_params = sampling_params + + self.child_requests = set() + self.output_aggregator = [None] * sampling_params.n if ( + sampling_params.output_kind + == RequestOutputKind.FINAL_ONLY) else [] + self.max_num_generation_tokens = 0 + self.cached_child_sampling_params = None + + def _get_child_sampling_params( + self, + index: int, + ) -> SamplingParams: + """Efficiently obtain child `sampling_params` + + If `sampling_params.seed` is not `None` then + each child request requires a unique clone of + parent `sampling_params` with a unique seed. + + Args: + index: index within `n` child requests + + Returns: + Child `sampling_params` instance. 
+ """ + seed = self.sampling_params.seed + if self.cached_child_sampling_params: + # Reuse child sampling_params data structure + return self.cached_child_sampling_params + # Build child sampling_params + child_sampling_params = copy(self.sampling_params) + child_sampling_params.n = 1 + if seed is None: + # Cache child sampling_params for later reuse + self.cached_child_sampling_params = child_sampling_params + else: + # Each child gets a clone with a unique seed + child_sampling_params.seed = seed + index + return child_sampling_params + + def get_child_info(self, index: int) -> tuple[str, SamplingParams]: + """Get child request ID and sampling params. + + Args: + index: index within `n` child requests. + + Returns: + (request ID, sampling_params) tuple + """ + child_req_id = f"{index}_{self.request_id}" + self.child_requests.add(child_req_id) + return child_req_id, self._get_child_sampling_params(index) + + @property + def n(self) -> int: + return self.sampling_params.n + + def get_outputs( + self, + child_request_id: str, + completion_output: CompletionOutput, + ) -> tuple[str, list[CompletionOutput], bool]: + if completion_output.finished(): + self.child_requests.remove(child_request_id) + + if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY: + # If streaming, just return the current output. + outputs = [completion_output] + else: + # If not streaming, aggregate the n final outputs. + self.output_aggregator[completion_output.index] = completion_output + outputs = [] if self.child_requests else self.output_aggregator + + finished = not self.child_requests + return self.request_id, outputs, finished + + def observe_num_generation_tokens(self, num_generation_tokens: int): + self.max_num_generation_tokens = max(num_generation_tokens, + self.max_num_generation_tokens) + return self.max_num_generation_tokens + + @staticmethod + def observe_finished_request(parent_req: Optional['ParentRequest'], + iteration_stats: IterationStats, + num_generation_tokens: int): + + n_param = parent_req.n if parent_req is not None else 1 + + if parent_req is not None: + num_generation_tokens = parent_req.observe_num_generation_tokens( + num_generation_tokens) + + # Child requests finished, we can now record to iteration stats + if parent_req is None or not parent_req.child_requests: + iteration_stats.max_num_generation_tokens_iter.append( + num_generation_tokens) + iteration_stats.n_params_iter.append(n_param) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py new file mode 100644 index 0000000..0d28928 --- /dev/null +++ b/vllm/v1/engine/processor.py @@ -0,0 +1,326 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from collections.abc import Mapping +from typing import Optional, Union + +from vllm.config import VllmConfig +from vllm.inputs import ProcessorInputs, PromptType +from vllm.inputs.parse import split_enc_dec_inputs +from vllm.inputs.preprocess import InputPreprocessor +from vllm.lora.request import LoRARequest +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + MultiModalRegistry) +from vllm.multimodal.inputs import PlaceholderRange +from vllm.multimodal.utils import merge_and_sort_multimodal_metadata +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.structured_output.backend_guidance import ( + 
validate_guidance_grammar) +from vllm.v1.structured_output.utils import ( + validate_structured_output_request_xgrammar) + + +class Processor: + + def __init__( + self, + vllm_config: VllmConfig, + tokenizer: BaseTokenizerGroup, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.decoding_config = vllm_config.decoding_config + self.tokenizer = tokenizer + + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer, + mm_registry) + + # Multi-modal hasher (for images) + self.use_hash = ( + not self.model_config.disable_mm_preprocessor_cache) or \ + self.cache_config.enable_prefix_caching + + def _validate_logprobs( + self, + params: SamplingParams, + ) -> None: + max_logprobs = self.model_config.max_logprobs + # Validate sample logprobs. + if params.logprobs and params.logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {params.logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + # Validate prompt logprobs. + if params.prompt_logprobs and params.prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {params.prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + def _validate_sampling_params( + self, + params: SamplingParams, + ) -> None: + self._validate_structured_output(params) + + if params.allowed_token_ids is None: + return + if not params.allowed_token_ids: + raise ValueError("allowed_token_ids is not None and empty!") + vocab_size = self.model_config.get_vocab_size() + if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): + raise ValueError( + "allowed_token_ids contains out-of-vocab token id!") + + def _validate_supported_sampling_params( + self, + params: SamplingParams, + ) -> None: + # Best of not yet supported. + if params.best_of is not None and params.best_of > 1: + raise ValueError("vLLM V1 does not yet support best_of.") + # Logits processors not supported. + if params.logits_processors: + raise ValueError("vLLM V1 does not support per request " + "user provided logits processors.") + + def _validate_params( + self, + params: Union[SamplingParams, PoolingParams], + ): + """ + Validate supported SamplingParam. + Should raise ValueError if unsupported for API Server. 
+ """ + + if not isinstance(params, SamplingParams): + raise ValueError("V1 does not yet support Pooling models.") + + self._validate_logprobs(params) + self._validate_sampling_params(params) + self._validate_supported_sampling_params(params) + + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + + def _validate_structured_output(self, params: SamplingParams) -> None: + if not params.guided_decoding or not self.decoding_config: + return + + supported_backends = [ + "xgrammar", "xgrammar:disable-any-whitespace", "guidance", + "guidance:disable-any-whitespace", "auto" + ] + engine_level_backend = self.decoding_config.guided_decoding_backend + if engine_level_backend not in supported_backends: + raise ValueError(f"Only {supported_backends} structured output is " + "supported in V1.") + if params.guided_decoding.backend: + if params.guided_decoding.backend != engine_level_backend: + raise ValueError("Request-level structured output backend " + "must match engine-level backend. " + f"{params.guided_decoding.backend}" + f" != {engine_level_backend}") + else: + params.guided_decoding.backend = engine_level_backend + import vllm.platforms + if vllm.platforms.current_platform.is_tpu(): + raise ValueError("Structured output is not supported on TPU.") + + # Request content validation + if engine_level_backend.startswith("xgrammar"): + # xgrammar with no fallback + validate_structured_output_request_xgrammar(params) + params.guided_decoding.backend = engine_level_backend + elif engine_level_backend == "auto": + # "auto" is an opt-in to opinionated behavior where we try to + # choose a backend based on request contents. This is not the + # default as it is less predictable and subject to change + # between releases as feature support changes. + try: + validate_structured_output_request_xgrammar(params) + params.guided_decoding.backend = "xgrammar" + except ValueError: + # The request includes some jsonschema feature(s) that + # are not supported in xgrammar. Fall back to guidance. + params.guided_decoding.backend = "guidance" + + if engine_level_backend.startswith("guidance"): + # TODO ideally we would have the LLTokenizer here as Lark syntax + # allows <|special_token|> and similar, see + # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens + # Without tokenizer these are disallowed in grammars. + validate_guidance_grammar(params, tokenizer=None) + params.guided_decoding.backend = engine_level_backend + + def process_inputs( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> EngineCoreRequest: + + # TODO(woosuk): Support pooling models. + # TODO(woosuk): Support encoder-decoder models. + + self._validate_lora(lora_request) + self._validate_params(params) + if priority != 0: + raise ValueError("V1 does not support priority yet.") + if trace_headers is not None: + raise ValueError("V1 does not support tracing yet.") + if prompt_adapter_request is not None: + raise ValueError("V1 does not support prompt_adapter_request.") + + if arrival_time is None: + arrival_time = time.time() + + # Process inputs, which includes: + # 1. 
Tokenize text prompt, with LoRA request if one exists. + # 2. For multimodal models with a merged preprocessor, preprocess + # multimodal data and expand prompt token ids accordingly. + # 3. Apply prompt adapter to prompt token ids if one exists. + processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( + prompt, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + return_mm_hashes=self.use_hash, + ) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + self._validate_model_inputs(processed_inputs, lora_request) + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + + # TODO: Impl encoder-decoder + if encoder_inputs is not None: + raise NotImplementedError + + assert isinstance(params, SamplingParams) + # TODO: can we avoid cloning here in multiproc case? + sampling_params = params.clone() + # If unset max tokens, then generate up to the max_model_len. + if sampling_params.max_tokens is None: + sampling_params.max_tokens = ( + self.model_config.max_model_len - + len(decoder_inputs["prompt_token_ids"])) + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id) + sampling_params.update_from_tokenizer( + self.tokenizer.get_lora_tokenizer(lora_request)) + + # Multimodal related. + sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None + sorted_mm_positions: Optional[list[PlaceholderRange]] = None + sorted_mm_hashes: Optional[list[str]] = None + if decoder_inputs["type"] == "multimodal": + decoder_mm_inputs = decoder_inputs["mm_kwargs"] + + # Merge and flatten multimodal placeholders, hashes and inputs + # from dictionaries to lists, and sort them by each item's position + # in the input sequence. + ( + sorted_item_modalities, + sorted_mm_positions, + sorted_mm_hashes, + ) = merge_and_sort_multimodal_metadata( + decoder_inputs["mm_placeholders"], + decoder_inputs["mm_hashes"] if self.use_hash else None, + ) + + # The output of merged multi-modal processor (`decoder_mm_inputs`) + # is a single MultiModalKwargs for all items from all modalities. + # This code flattens kwargs for individual items in a list and + # sorts them by each item's position in the input sequence if there + # are multiple modalities. 
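Concretely, the per-modality walk described in the comment above can be shown in isolation. The helper below is a standalone sketch with made-up names; it only mirrors the index bookkeeping, not the real `MultiModalKwargs` objects.

```python
from collections import defaultdict


def flatten_by_position(sorted_item_modalities, items_per_modality):
    """Return one kwargs entry per item, ordered by position in the prompt.

    sorted_item_modalities: modality name per item, already sorted by the
        item's position in the input sequence (e.g. ["image", "audio", "image"]).
    items_per_modality: modality -> items in intra-modality order.
    """
    used = defaultdict(int)  # next unconsumed index per modality
    flat = []
    for modality in sorted_item_modalities:
        flat.append(items_per_modality[modality][used[modality]])
        used[modality] += 1
    return flat


# A prompt interleaving two images and one audio clip:
print(flatten_by_position(
    ["image", "audio", "image"],
    {"image": ["img_kwargs_0", "img_kwargs_1"], "audio": ["aud_kwargs_0"]}))
# -> ['img_kwargs_0', 'aud_kwargs_0', 'img_kwargs_1']
```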
+ unique_modalities = set(sorted_item_modalities) + if len(unique_modalities) > 1: + sorted_mm_inputs = [] + used_indices = {modality: 0 for modality in unique_modalities} + for modality in sorted_item_modalities: + items = decoder_mm_inputs.get_items(modality) + item = items[used_indices[modality]] + sorted_mm_inputs.append(MultiModalKwargs.from_items([item + ])) + used_indices[modality] += 1 + else: + sorted_mm_inputs = [ + MultiModalKwargs.from_items([item]) for item in + decoder_mm_inputs.get_items(sorted_item_modalities[0]) + ] + + return EngineCoreRequest( + request_id=request_id, + prompt=decoder_inputs.get("prompt"), + prompt_token_ids=decoder_inputs["prompt_token_ids"], + mm_inputs=sorted_mm_inputs, + mm_hashes=sorted_mm_hashes, + mm_placeholders=sorted_mm_positions, + sampling_params=sampling_params, + eos_token_id=eos_token_id, + arrival_time=arrival_time, + lora_request=lora_request, + ) + + def _validate_model_inputs(self, + inputs: ProcessorInputs, + lora_request: Optional[LoRARequest] = None): + encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) + + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + if self.model_config.is_multimodal_model: + prompt_inputs = decoder_inputs + else: + prompt_inputs = encoder_inputs or decoder_inputs + + prompt_ids = prompt_inputs["prompt_token_ids"] + + if prompt_ids is None or len(prompt_ids) == 0: + raise ValueError("Prompt cannot be empty") + + max_input_id = max(prompt_ids) + max_allowed = self.tokenizer.get_lora_tokenizer( + lora_request).max_token_id + if max_input_id > max_allowed: + raise ValueError( + "Token id {} is out of vocabulary".format(max_input_id)) + + if len(prompt_ids) >= self.model_config.max_model_len: + raise ValueError( + f"Prompt length of {len(prompt_ids)} is longer than the " + f"maximum model length of {self.model_config.max_model_len}.") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. 
For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens diff --git a/vllm/v1/executor/__init__.py b/vllm/v1/executor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py new file mode 100644 index 0000000..e3a4cd9 --- /dev/null +++ b/vllm/v1/executor/abstract.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 + +from concurrent.futures import Future +from typing import Union + +import torch +import torch.distributed as dist + +from vllm.config import VllmConfig +from vllm.executor.executor_base import ExecutorBase +from vllm.executor.uniproc_executor import ( # noqa + ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) +from vllm.executor.uniproc_executor import ( # noqa + UniProcExecutor as UniProcExecutorV0) +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.outputs import ModelRunnerOutput + + +class Executor(ExecutorBase): + """ + Abstract class for v1 executors, mainly define some methods for v1. + For methods shared by v0 and v1, define them in ExecutorBase""" + + @staticmethod + def get_class(vllm_config: VllmConfig) -> type["Executor"]: + executor_class: type[Executor] + parallel_config = vllm_config.parallel_config + distributed_executor_backend = ( + parallel_config.distributed_executor_backend) + # distributed_executor_backend must be set in VllmConfig.__post_init__ + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {distributed_executor_backend}.") + executor_class = distributed_executor_backend + elif distributed_executor_backend == "ray": + from vllm.v1.executor.ray_distributed_executor import ( # noqa + RayDistributedExecutor) + executor_class = RayDistributedExecutor + elif distributed_executor_backend == "mp": + from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor + elif distributed_executor_backend == "uni": + executor_class = UniProcExecutor + elif distributed_executor_backend == "external_launcher": + # TODO: make v1 scheduling deterministic + # to support external launcher + executor_class = ExecutorWithExternalLauncher + else: + raise ValueError("Unknown distributed executor backend: " + f"{distributed_executor_backend}") + return executor_class + + def initialize_from_config(self, + kv_cache_configs: list[KVCacheConfig]) -> None: + """ + Initialize the KV caches and begin the model execution loop of the + underlying workers. 
+ """ + self.collective_rpc("initialize_from_config", + args=(kv_cache_configs, )) + self.collective_rpc("compile_or_warm_up_model") + + def determine_available_memory(self) -> list[int]: # in bytes + output = self.collective_rpc("determine_available_memory") + return output + + def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: + output = self.collective_rpc("get_kv_cache_spec") + return output + + def execute_model( + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: + output = self.collective_rpc("execute_model", + args=(scheduler_output, )) + return output[0] + + @property + def max_concurrent_batches(self) -> int: + return 1 + + def profile(self, is_start: bool = True): + self.collective_rpc("profile", args=(is_start, )) + + +class UniProcExecutor(UniProcExecutorV0, Executor): + pass + + +class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): + + def determine_available_memory(self) -> list[int]: # in bytes + # same as determine_num_available_blocks in v0, + # we need to get the min across all ranks. + memory = super().determine_available_memory() + from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group + memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) + dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + return [memory_tensor.item()] diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py new file mode 100644 index 0000000..1d5175e --- /dev/null +++ b/vllm/v1/executor/multiproc_executor.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import pickle +import signal +import sys +import time +import traceback +import weakref +from dataclasses import dataclass +from enum import Enum, auto +from functools import partial +from multiprocessing.process import BaseProcess +from typing import Any, Callable, Optional, Union + +import cloudpickle +import psutil +import zmq + +from vllm.config import VllmConfig +from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel) +from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) +from vllm.executor.multiproc_worker_utils import ( + _add_prefix, set_multiprocessing_worker_envs) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_mp_context, + get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx) +from vllm.v1.executor.abstract import Executor +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 5000 +POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 + + +class MultiprocExecutor(Executor): + + def _init_executor(self) -> None: + # Call self.shutdown at exit to clean up + # and ensure workers will be terminated. + self._finalizer = weakref.finalize(self, self.shutdown) + + # The child processes will send SIGUSR1 when unrecoverable + # errors happen. + def sigusr1_handler(signum, frame): + logger.fatal( + "MulitprocExecutor got fatal signal from worker processes, " + "shutting down. See stack trace above for root cause issue.") + # Propagate error up to parent process. 
+ parent_process = psutil.Process().parent() + parent_process.send_signal(signal.SIGUSR1) + self.shutdown() + + signal.signal(signal.SIGUSR1, sigusr1_handler) + + self.world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + assert self.world_size == tensor_parallel_size, ( + f"world_size ({self.world_size}) must be equal to the " + f"tensor_parallel_size ({tensor_parallel_size}). " + f"Pipeline parallelism is not yet implemented in v1") + + # Set multiprocessing envs that are common to V0 and V1 + set_multiprocessing_worker_envs(self.parallel_config) + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + # Initialize worker and set up message queues for SchedulerOutputs + # and ModelRunnerOutputs + self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() + + # Create workers + self.workers: list[WorkerProcHandle] = [] + for rank in range(self.world_size): + worker = WorkerProc.make_worker_process(self.vllm_config, rank, + rank, + distributed_init_method, + scheduler_output_handle) + self.workers.append(worker) + + # Ensure message queues are ready. Will deadlock if re-ordered + # Must be kept consistent with the WorkerProc + self.rpc_broadcast_mq.wait_until_ready() + for w in self.workers: + w.worker_response_mq.wait_until_ready() + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None) -> list[Any]: + start_time = time.monotonic() + kwargs = kwargs or {} + + # NOTE: If the args are heterogeneous, then we pack them into a list, + # and unpack them in the method of every worker, because every worker + # knows their own rank. + try: + if isinstance(method, str): + send_method = method + else: + send_method = cloudpickle.dumps( + method, protocol=pickle.HIGHEST_PROTOCOL) + self.rpc_broadcast_mq.enqueue((send_method, args, kwargs)) + + responses = [None] * self.world_size + for w in self.workers: + dequeue_timeout = timeout - (time.monotonic() - start_time + ) if timeout is not None else None + status, result = w.worker_response_mq.dequeue( + timeout=dequeue_timeout) + + if status != WorkerProc.ResponseStatus.SUCCESS: + if isinstance(result, Exception): + raise result + else: + raise RuntimeError("Worker failed") + + responses[w.rank] = result + + return responses + except TimeoutError as e: + raise TimeoutError(f"RPC call to {method} timed out.") from e + except Exception as e: + # Re-raise any other exceptions + raise e + + def _ensure_worker_termination(self): + """Ensure that all worker processes are terminated. Assumes workers have + received termination requests. Waits for processing, then sends + termination and kill signals if needed.""" + + def wait_for_termination(procs, timeout): + if not time: + # If we are in late stage shutdown, the interpreter may replace + # `time` with `None`. 
+ return all(not proc.is_alive() for proc in procs) + start_time = time.time() + while time.time() - start_time < timeout: + if all(not proc.is_alive() for proc in procs): + return True + time.sleep(0.1) + return False + + # Send SIGTERM if still running + active_procs = [w.proc for w in self.workers if w.proc.is_alive()] + for p in active_procs: + p.terminate() + if not wait_for_termination(active_procs, 4): + # Send SIGKILL if still running + active_procs = [p for p in active_procs if p.is_alive()] + for p in active_procs: + p.kill() + + self._cleanup_sockets() + + def _cleanup_sockets(self): + for w in self.workers: + # Remove the zmq ipc socket file + socket_path = w.ready_path.replace("ipc://", "") + if os and os.path.exists(socket_path): + os.remove(socket_path) + + def shutdown(self): + """Properly shut down the executor and its workers""" + if not getattr(self, 'shutting_down', False): + self.shutting_down = True + for w in self.workers: + w.worker_response_mq = None + self._ensure_worker_termination() + + self.rpc_broadcast_mq = None + + def check_health(self) -> None: + self.collective_rpc("check_health", timeout=10) + return + + +@dataclass +class WorkerProcHandle: + proc: BaseProcess + rank: int + ready_path: str + worker_response_mq: MessageQueue # The worker process writes to this MQ + + +class WorkerProc: + """Wrapper that runs one Worker in a separate process.""" + + READY_STR = "READY" + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle: Handle, + ready_path: str, + ): + self.rank = rank + wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) + # TODO: move `init_worker` to executor level as a collective rpc call + all_kwargs: list[dict] = [ + {} for _ in range(vllm_config.parallel_config.world_size) + ] + all_kwargs[rank] = { + "vllm_config": vllm_config, + "local_rank": local_rank, + "rank": rank, + "distributed_init_method": distributed_init_method, + "is_driver_worker": rank == 0, + } + wrapper.init_worker(all_kwargs) + self.worker = wrapper + + pid = os.getpid() + _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid) + _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid) + + # Initialize MessageQueue for receiving SchedulerOutput + self.rpc_broadcast_mq = MessageQueue.create_from_handle( + input_shm_handle, self.worker.rank) + + # Initializes a message queue for sending the model output + self.worker_response_mq = MessageQueue(1, 1) + worker_response_mq_handle = self.worker_response_mq.export_handle() + + # Send Readiness signal to EngineCore process. + # Set linger here because we want to ensure the message has + # been sent before the context is closed. + with zmq_socket_ctx(ready_path, zmq.constants.PUSH, + linger=10000) as ready_socket: + payload = pickle.dumps(worker_response_mq_handle, + protocol=pickle.HIGHEST_PROTOCOL) + ready_socket.send_string(WorkerProc.READY_STR) + ready_socket.send(payload) + + self.worker.init_device() + self.worker.load_model() + + @staticmethod + def make_worker_process( + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle, # Receive SchedulerOutput + ) -> WorkerProcHandle: + context = get_mp_context() + + # ZMQ path for worker to send ready message and shm_broadcast handle + # back to core process. 
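The readiness handshake carried over this path is small: the worker pushes the literal `READY` string followed by a pickled handle for its response queue, and the parent pulls both (see `WorkerProc.__init__` and `wait_for_startup` below). A stripped-down sketch in raw pyzmq, assuming the parent binds the ipc endpoint and the worker connects; the real code goes through the `zmq_socket_ctx` helper:

```python
import pickle
import zmq


def worker_side(ready_path: str, response_mq_handle) -> None:
    # Worker process: announce readiness, then ship the queue handle.
    sock = zmq.Context.instance().socket(zmq.PUSH)
    sock.connect(ready_path)
    sock.send_string("READY")                    # WorkerProc.READY_STR
    sock.send(pickle.dumps(response_mq_handle))  # worker_response_mq handle
    sock.close(linger=10000)                     # ensure the message is flushed


def parent_side(ready_path: str):
    # Executor process: block until the worker checks in, return its handle.
    sock = zmq.Context.instance().socket(zmq.PULL)
    sock.bind(ready_path)
    assert sock.recv_string() == "READY"
    return pickle.loads(sock.recv())
```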
+ ready_path = get_open_zmq_ipc_path() + + process_kwargs = { + "vllm_config": vllm_config, + "local_rank": local_rank, + "rank": rank, + "distributed_init_method": distributed_init_method, + "input_shm_handle": input_shm_handle, + "ready_path": ready_path, + } + # Run EngineCore busy loop in background process. + proc = context.Process(target=WorkerProc.worker_main, + kwargs=process_kwargs, + daemon=True) + + with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket: + proc.start() + + # Wait for startup + worker_response_mq_handle = WorkerProc.wait_for_startup( + proc, ready_socket) + + worker_response_mq = MessageQueue.create_from_handle( + worker_response_mq_handle, 0) + + return WorkerProcHandle(proc, rank, ready_path, worker_response_mq) + + def shutdown(self): + self.rpc_broadcast_mq = None + self.worker_response_mq = None + destroy_model_parallel() + destroy_distributed_environment() + + @staticmethod + def worker_main(*args, **kwargs): + """ Worker initialization and execution loops. + This runs a background process """ + + # Signal handler used for graceful termination. + # SystemExit exception is only raised once to allow this and worker + # processes to terminate without error + shutdown_requested = False + + def signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the worker + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + worker = None + try: + worker = WorkerProc(*args, **kwargs) + + # Ensure message queues are ready. Will deadlock if re-ordered. + # Must be kept consistent with the Executor + worker.rpc_broadcast_mq.wait_until_ready() + worker.worker_response_mq.wait_until_ready() + + worker.worker_busy_loop() + + except SystemExit: + logger.debug("Worker interrupted.") + + except Exception: + # worker_busy_loop sends exceptions exceptons to Executor + # for shutdown, but if there is an error in startup or an + # error with IPC itself, we need to alert the parent. + psutil.Process().parent().send_signal(signal.SIGUSR1) + raise + + finally: + # Clean up once worker exits busy loop + if worker is not None: + worker.shutdown() + worker = None + + @staticmethod + def wait_for_startup( + proc: BaseProcess, + ready_socket: zmq.Socket, + ) -> Optional[Handle]: + """Wait until the Worker is ready.""" + + # Wait for Worker to send READY. 
+ while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for WorkerProc to startup.") + + if not proc.is_alive(): + raise RuntimeError("WorkerProc failed to start.") + + message = ready_socket.recv_string() + assert message == WorkerProc.READY_STR + handle_frame = ready_socket.recv(copy=False) + handle = pickle.loads(handle_frame.buffer) + return handle + + class ResponseStatus(Enum): + SUCCESS = auto() + FAILURE = auto() + + def worker_busy_loop(self): + """Main busy loop for Multiprocessing Workers""" + while True: + method, args, kwargs = self.rpc_broadcast_mq.dequeue() + + try: + if isinstance(method, str): + func = getattr(self.worker, method) + elif isinstance(method, bytes): + func = partial(cloudpickle.loads(method), self.worker) + output = func(*args, **kwargs) + except Exception as e: + # Notes have been introduced in python 3.11 + if hasattr(e, "add_note"): + e.add_note(traceback.format_exc()) + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.FAILURE, e)) + logger.exception("WorkerProc hit an exception: %s", exc_info=e) + continue + + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.SUCCESS, output)) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py new file mode 100644 index 0000000..320ebfd --- /dev/null +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +from concurrent.futures import Future +from typing import Union + +from vllm.executor.ray_distributed_executor import ( # noqa + RayDistributedExecutor as RayDistributedExecutorV0) +from vllm.v1.executor.abstract import Executor +from vllm.v1.outputs import ModelRunnerOutput + + +class FutureWrapper(Future): + """A wrapper around a Ray output reference to meet the interface + of .execute_model(). + """ + + def __init__(self, ref): + super().__init__() + self.ref = ref + + def result(self, timeout=None): + if timeout is not None: + raise NotImplementedError("timeout is not supported") + return self.ref.get() + + +class RayDistributedExecutor(RayDistributedExecutorV0, Executor): + """Ray distributed executor using Ray Compiled Graphs.""" + + @property + def max_concurrent_batches(self) -> int: + """Ray distributed executor supports pipeline parallelism, + meaning that it allows PP size batches to be executed concurrently. + """ + return self.parallel_config.pipeline_parallel_size + + def execute_model( + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: + """Execute the model on the Ray workers. + + Args: + scheduler_output: The scheduler output to execute. + + Returns: + The model runner output. + """ + # Build the compiled DAG for the first time. + if self.forward_dag is None: # type: ignore + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + refs = self.forward_dag.execute(scheduler_output) # type: ignore + + # When PP is not used, we block here until the result is available. + if self.max_concurrent_batches == 1: + return refs[0].get() + + # When PP is used, we return a FutureWrapper immediately so that + # the scheduler can yield to the next batch. 
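From the caller's point of view the two return shapes (a concrete `ModelRunnerOutput` when pipeline parallelism is off, a `FutureWrapper` when it is on) can be consumed through one small helper. A hedged sketch; `resolve` is an illustrative name, not anything in vLLM:

```python
from concurrent.futures import Future


def resolve(output_or_future):
    """Block only when the caller actually needs the model output."""
    if isinstance(output_or_future, Future):
        # FutureWrapper.result() simply blocks on the underlying Ray ref.
        return output_or_future.result()
    return output_or_future
```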
+ return FutureWrapper(refs[0]) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py new file mode 100644 index 0000000..8185c21 --- /dev/null +++ b/vllm/v1/kv_cache_interface.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import cdiv, get_dtype_size +import vllm.envs as envs +logger = init_logger(__name__) + + +@dataclass +class KVCacheSpec: + """ + A base class for specifying the KV cache format of one layer. + """ + + # number of tokens in a block + block_size: int + + @property + def type_id(self) -> str: + """ + The type identifier of this KV cache. + Return different strings for layers with different KV cache type (e.g., + different number of tokens like full attention vs sliding window + attention, different KV cache size per token like layers with different + number of heads) + + Returns: + The type identifier of this KV cache. + """ + raise NotImplementedError + + @property + def page_size_bytes(self) -> int: + """ + The size of a page with `block_size` tokens in bytes. + + Returns: + The page size + """ + raise NotImplementedError + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + """ + The maximum possible memory usage of this KV cache in bytes. + + Returns: + The KV cache size in bytes + """ + raise NotImplementedError + + +@dataclass +class AttentionSpec(KVCacheSpec): + num_kv_heads: int + head_size: int + dtype: torch.dtype + use_mla: bool + + @property + def page_size_bytes(self) -> int: + # For MLA we only store a single latent vector + coef = 1 if self.use_mla else 2 + if envs.VLLM_USE_INT8_MLA: + self.dtype = torch.int8 + return coef * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + ## only for int8 mla + @property + def scale_page_size_bytes(self) -> int: + # For MLA we only store a single latent vector + coef = 1 if self.use_mla else 2 + if envs.VLLM_USE_INT8_MLA: + return coef * self.block_size * self.num_kv_heads * 2 \ + * get_dtype_size(torch.float32) + else: + return 0 + + +@dataclass +class FullAttentionSpec(AttentionSpec): + + @property + def type_id(self) -> str: + return f"full_attention_{self.block_size}_{self.page_size_bytes}" + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * (self.page_size_bytes + self.scale_page_size_bytes) + + +@dataclass +class SlidingWindowSpec(AttentionSpec): + sliding_window: int + + def __post_init__(self): + assert not self.use_mla, "MLA is not supported for sliding window" + + @property + def type_id(self) -> str: + return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens) + + # During chunked prefill, we allocate KV cache for the last + # `self.sliding_window-1` computed tokens plus the newly scheduled + # tokens. And we won't allocate KV cache for more than `max_model_len` + # tokens. + num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, + max_model_len) + + # +1 here because the sliding window may not start from the beginning + # of the block. 
For example, if the block size is 4 and num_token + # is 4, we need two blocks [XXCD] [EF] to store the sliding + # window [CDEF] of 6 tokens. + return (cdiv(num_tokens, self.block_size) + 1) * (self.page_size_bytes + self.scale_page_size_bytes) + + +@dataclass +class KVCacheTensor: + """ + A dataclass for specifying how the workers should initialize the KV cache + for a layer. Only contains the size of KV cache for that layer for now. Will + be extended to support multiple layers sharing the same memory pool. + """ + size: int # The size of KV cache Tensor in bytes + + +@dataclass +class KVCacheGroupSpec: + """ + Represents a group of model layers that share the same KV cache block table. + These layers are regarded as one layer in the KV cache manager. + """ + # The names of model layers in this group + layer_names: list[str] + # The KV cache spec of this manager layer + kv_cache_spec: KVCacheSpec + + +@dataclass +class KVCacheConfig: + """ + The KV cache configuration of a model. + """ + """The number of KV cache blocks""" + num_blocks: int + """layer_name -> how to initialize KV cache for that layer""" + tensors: dict[str, KVCacheTensor] + """ + The kv cache groups of the model. + The layers in the models are repeated with some patterns, e.g., a model + with 10 full attention layers and 20 sliding window attention layers can be + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + The KVCacheManager allocates different block tables for each of the 3 layers + in the pattern, and repeats each of them 10 times to generate the + block_table for the 30 layers in the model. + Therefore, we can group the layers in the model into 3 groups, each of which + contains 10 layers in the model. + The KVCacheManager allocates the block_table for each group based on its + kv_cache spec, and the model runner applies the block table to each layer + in the group. + For example: + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. + 2. (WIP) A model with 10 full attention layers and 20 sliding window + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + there are 3 groups, each of which represents 10 layers in the model. + """ + kv_cache_groups: list[KVCacheGroupSpec] diff --git a/vllm/v1/metrics/__init__.py b/vllm/v1/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py new file mode 100644 index 0000000..73883d9 --- /dev/null +++ b/vllm/v1/metrics/loggers.py @@ -0,0 +1,469 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from abc import ABC, abstractmethod +from typing import Optional + +import numpy as np +import prometheus_client + +from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics +from vllm.v1.engine import FinishReason +from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.spec_decode.metrics import SpecDecodingMetrics + +logger = init_logger(__name__) + +_LOCAL_LOGGING_INTERVAL_SEC = 5.0 + + +class StatLoggerBase(ABC): + + @abstractmethod + def record(self, scheduler_stats: SchedulerStats, + iteration_stats: Optional[IterationStats]): + ... 
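
As a sanity check on the `page_size_bytes` and `max_memory_usage_bytes` formulas in `kv_cache_interface.py` above, here is a small worked example. The block size, head count, head size, and sliding-window length are illustrative numbers (only `max_model_len=10000` matches this repo's default), and the int8-MLA scale page is ignored since `scale_page_size_bytes` is zero for non-MLA layers.

```python
# Worked example of the per-layer KV-cache sizing formulas above.
def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division, as in vllm.utils.cdiv


block_size = 16        # tokens per KV block (illustrative)
num_kv_heads = 8
head_size = 128
dtype_size = 2         # bytes per element for fp16

# page_size_bytes = coef * block_size * num_kv_heads * head_size * dtype_size
# coef == 2 because both K and V are stored; MLA stores one latent => coef 1.
page_size = 2 * block_size * num_kv_heads * head_size * dtype_size
print(page_size)       # 65536 bytes per block

# FullAttentionSpec: enough blocks to cover the full context length.
max_model_len = 10000
full_attn_bytes = cdiv(max_model_len, block_size) * page_size
print(full_attn_bytes)  # 625 blocks * 64 KiB = 40,960,000 bytes

# SlidingWindowSpec: only the last (sliding_window - 1) computed tokens plus
# the newly scheduled tokens need KV, capped at max_model_len; +1 block
# because the window may straddle a block boundary.
sliding_window = 4096
max_num_batched_tokens = 2048
num_tokens = min(sliding_window - 1 + max_num_batched_tokens, max_model_len)
sw_bytes = (cdiv(num_tokens, block_size) + 1) * page_size
print(sw_bytes)         # (384 + 1) blocks * 64 KiB = 25,231,360 bytes
```
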
+ + def log(self): # noqa + pass + + +class LoggingStatLogger(StatLoggerBase): + + def __init__(self, engine_index: int = 0): + self.engine_index = engine_index + self._reset(time.monotonic()) + self.last_scheduler_stats = SchedulerStats() + # Prefix cache metrics. This cannot be reset. + # TODO: Make the interval configurable. + self.prefix_caching_metrics = PrefixCachingMetrics() + self.spec_decoding_metrics = SpecDecodingMetrics() + + def _reset(self, now): + self.last_log_time = now + + # Tracked stats over current local logging interval. + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] + + def _track_iteration_stats(self, iteration_stats: IterationStats): + # Save tracked stats for token counters. + self.num_prompt_tokens.append(iteration_stats.num_prompt_tokens) + self.num_generation_tokens.append( + iteration_stats.num_generation_tokens) + + def _get_throughput(self, tracked_stats: list[int], now: float) -> float: + # Compute summary metrics for tracked stats + return float(np.sum(tracked_stats) / (now - self.last_log_time)) + + def record(self, scheduler_stats: SchedulerStats, + iteration_stats: Optional[IterationStats]): + """Log Stats to standard output.""" + + if iteration_stats: + self._track_iteration_stats(iteration_stats) + + self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_metrics.observe( + scheduler_stats.spec_decoding_stats) + + self.last_scheduler_stats = scheduler_stats + + def log(self): + now = time.monotonic() + prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) + generation_throughput = self._get_throughput( + self.num_generation_tokens, now) + + self._reset(now) + + scheduler_stats = self.last_scheduler_stats + + # Format and print output. + logger.info( + "Engine %03d: " + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Waiting: %d reqs, " + "GPU KV cache usage: %.1f%%, " + "Prefix cache hit rate: %.1f%%", + self.engine_index, + prompt_throughput, + generation_throughput, + scheduler_stats.num_running_reqs, + scheduler_stats.num_waiting_reqs, + scheduler_stats.gpu_cache_usage * 100, + self.prefix_caching_metrics.hit_rate * 100, + ) + + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_metrics.log() + + +class PrometheusStatLogger(StatLoggerBase): + + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + self._unregister_vllm_metrics() + + # Use this flag to hide metrics that were deprecated in + # a previous release and which will be removed future + self.show_hidden_metrics = \ + vllm_config.observability_config.show_hidden_metrics + + labelnames = ["model_name", "engine"] + labelvalues = [ + vllm_config.model_config.served_model_name, + str(engine_index) + ] + + max_model_len = vllm_config.model_config.max_model_len + + # + # Scheduler state + # + self.gauge_scheduler_running = prometheus_client.Gauge( + name="vllm:num_requests_running", + documentation="Number of requests in model execution batches.", + labelnames=labelnames).labels(*labelvalues) + + self.gauge_scheduler_waiting = prometheus_client.Gauge( + name="vllm:num_requests_waiting", + documentation="Number of requests waiting to be processed.", + labelnames=labelnames).labels(*labelvalues) + + # + # GPU cache + # + self.gauge_gpu_cache_usage = prometheus_client.Gauge( + name="vllm:gpu_cache_usage_perc", + documentation="GPU KV-cache usage. 
1 means 100 percent usage.", + labelnames=labelnames).labels(*labelvalues) + + self.counter_gpu_prefix_cache_queries = prometheus_client.Counter( + name="vllm:gpu_prefix_cache_queries", + documentation= + "GPU prefix cache queries, in terms of number of queried blocks.", + labelnames=labelnames).labels(*labelvalues) + + self.counter_gpu_prefix_cache_hits = prometheus_client.Counter( + name="vllm:gpu_prefix_cache_hits", + documentation= + "GPU prefix cache hits, in terms of number of cached blocks.", + labelnames=labelnames).labels(*labelvalues) + + # + # Counters + # + self.counter_num_preempted_reqs = prometheus_client.Counter( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemption from the engine.", + labelnames=labelnames).labels(*labelvalues) + + self.counter_prompt_tokens = prometheus_client.Counter( + name="vllm:prompt_tokens_total", + documentation="Number of prefill tokens processed.", + labelnames=labelnames).labels(*labelvalues) + + self.counter_generation_tokens = prometheus_client.Counter( + name="vllm:generation_tokens_total", + documentation="Number of generation tokens processed.", + labelnames=labelnames).labels(*labelvalues) + + self.counter_request_success: dict[FinishReason, + prometheus_client.Counter] = {} + counter_request_success_base = prometheus_client.Counter( + name="vllm:request_success_total", + documentation="Count of successfully processed requests.", + labelnames=labelnames + ["finished_reason"]) + for reason in FinishReason: + self.counter_request_success[ + reason] = counter_request_success_base.labels(*(labelvalues + + [str(reason)])) + + # + # Histograms of counts + # + self.histogram_num_prompt_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + self.histogram_num_generation_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + self.histogram_iteration_tokens = \ + prometheus_client.Histogram( + name="vllm:iteration_tokens_total", + documentation="Histogram of number of tokens per engine_step.", + buckets=build_cudagraph_buckets(vllm_config), + labelnames=labelnames).labels(*labelvalues) + + self.histogram_max_num_generation_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + self.histogram_n_request = \ + prometheus_client.Histogram( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + buckets=[1, 2, 5, 10, 20], + labelnames=labelnames).labels(*labelvalues) + + self.histogram_max_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + # + # Histogram of timing intervals + # + self.histogram_time_to_first_token = \ + prometheus_client.Histogram( + name="vllm:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 
0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0 + ], + labelnames=labelnames).labels(*labelvalues) + + self.histogram_time_per_output_token = \ + prometheus_client.Histogram( + name="vllm:time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, + 0.75, 1.0, 2.5 + ], + labelnames=labelnames).labels(*labelvalues) + + request_latency_buckets = [ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0 + ] + self.histogram_e2e_time_request = \ + prometheus_client.Histogram( + name="vllm:e2e_request_latency_seconds", + documentation="Histogram of e2e request latency in seconds.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_queue_time_request = \ + prometheus_client.Histogram( + name="vllm:request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_inference_time_request = \ + prometheus_client.Histogram( + name="vllm:request_inference_time_seconds", + documentation= + "Histogram of time spent in RUNNING phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_prefill_time_request = \ + prometheus_client.Histogram( + name="vllm:request_prefill_time_seconds", + documentation= + "Histogram of time spent in PREFILL phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_decode_time_request = \ + prometheus_client.Histogram( + name="vllm:request_decode_time_seconds", + documentation= + "Histogram of time spent in DECODE phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + + # + # LoRA metrics + # + self.gauge_lora_info: Optional[prometheus_client.Gauge] = None + if vllm_config.lora_config is not None: + self.labelname_max_lora = "max_lora" + self.labelname_waiting_lora_adapters = "waiting_lora_adapters" + self.labelname_running_lora_adapters = "running_lora_adapters" + self.max_lora = vllm_config.lora_config.max_loras + self.gauge_lora_info = \ + prometheus_client.Gauge( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + labelnames=[ + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + self.labelname_running_lora_adapters, + ]) + + # + # Speculative Decoding metrics + # The acceptance rate can be calculated using a PromQL query: + # + # rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / + # rate(vllm:spec_decode_num_draft_tokens_total[$interval]) + # + self.counter_spec_decode_num_draft_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames).labels(*labelvalues) + self.counter_spec_decode_num_accepted_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames).labels(*labelvalues) + + # + # Cache config info metric + # + self.log_metrics_info("cache_config", vllm_config.cache_config) + + def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): + metrics_info = config_obj.metrics_info() + + name, documentation = None, None + if type == "cache_config": + name = "vllm:cache_config_info" + documentation = 
"Information of the LLMEngine CacheConfig" + assert name is not None, f"Unknown metrics info type {type}" + + # Info type metrics are syntactic sugar for a gauge permanently set to 1 + # Since prometheus multiprocessing mode does not support Info, emulate + # info here with a gauge. + info_gauge = prometheus_client.Gauge( + name=name, + documentation=documentation, + labelnames=metrics_info.keys()).labels(**metrics_info) + info_gauge.set(1) + + def record(self, scheduler_stats: SchedulerStats, + iteration_stats: Optional[IterationStats]): + """Log to prometheus.""" + self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) + + self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) + + if scheduler_stats.spec_decoding_stats is not None: + self.counter_spec_decode_num_draft_tokens.inc( + scheduler_stats.spec_decoding_stats.num_draft_tokens) + self.counter_spec_decode_num_accepted_tokens.inc( + scheduler_stats.spec_decoding_stats.num_accepted_tokens) + + if iteration_stats is None: + return + + self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs) + self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) + self.counter_generation_tokens.inc( + iteration_stats.num_generation_tokens) + self.histogram_iteration_tokens.observe( + iteration_stats.num_prompt_tokens + \ + iteration_stats.num_generation_tokens) + + for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter: + self.histogram_max_num_generation_tokens_request.observe( + max_gen_tokens) + for n_param in iteration_stats.n_params_iter: + self.histogram_n_request.observe(n_param) + for ttft in iteration_stats.time_to_first_tokens_iter: + self.histogram_time_to_first_token.observe(ttft) + for tpot in iteration_stats.time_per_output_tokens_iter: + self.histogram_time_per_output_token.observe(tpot) + + for finished_request in iteration_stats.finished_requests: + self.counter_request_success[finished_request.finish_reason].inc() + self.histogram_e2e_time_request.observe( + finished_request.e2e_latency) + self.histogram_queue_time_request.observe( + finished_request.queued_time) + self.histogram_prefill_time_request.observe( + finished_request.prefill_time) + self.histogram_inference_time_request.observe( + finished_request.inference_time) + self.histogram_decode_time_request.observe( + finished_request.decode_time) + self.histogram_num_prompt_tokens_request.observe( + finished_request.num_prompt_tokens) + self.histogram_num_generation_tokens_request.observe( + finished_request.num_generation_tokens) + self.histogram_max_tokens_request.observe( + finished_request.max_tokens_param) + + if self.gauge_lora_info is not None: + running_lora_adapters = \ + ",".join(iteration_stats.running_lora_adapters.keys()) + waiting_lora_adapters = \ + ",".join(iteration_stats.waiting_lora_adapters.keys()) + lora_info_labels = { + self.labelname_running_lora_adapters: running_lora_adapters, + self.labelname_waiting_lora_adapters: waiting_lora_adapters, + self.labelname_max_lora: self.max_lora, + } + self.gauge_lora_info.labels(**lora_info_labels)\ + .set_to_current_time() + + @staticmethod + def _unregister_vllm_metrics(): + # Unregister any existing vLLM collectors (for CI/CD + for collector in list(prometheus_client.REGISTRY._collector_to_names): + if 
hasattr(collector, "_name") and "vllm" in collector._name: + prometheus_client.REGISTRY.unregister(collector) + + +def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values until the value exceeds the specified maximum. + + """ + exponent = 0 + buckets: list[int] = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + +def build_1_2_5_buckets(max_value: int) -> list[int]: + """ + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + return build_buckets([1, 2, 5], max_value) + + +def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]: + if not vllm_config.model_config.enforce_eager: + buckets = vllm_config.compilation_config.\ + cudagraph_capture_sizes.copy() + buckets.sort() + return buckets + else: + return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py new file mode 100644 index 0000000..fd94926 --- /dev/null +++ b/vllm/v1/metrics/stats.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Optional + +from vllm.v1.spec_decode.metrics import SpecDecodingStats + +if TYPE_CHECKING: + from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason + from vllm.v1.engine.output_processor import RequestState + + +@dataclass +class PrefixCacheStats: + """Stores prefix cache hit statistics.""" + # Whether reset_prefix_cache was invoked. + reset: bool = False + # The number of requests in this update. + requests: int = 0 + # The number of queries in these requests. Note that "queries" here + # means the number of blocks that were queried from the cache. + queries: int = 0 + # The number of hits in these requests. 
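
The Prometheus logger above follows one consistent pattern: each metric is created once with `model_name`/`engine` label names and immediately bound to this engine's label values, and the request-size histograms reuse the 1-2-5 bucket ladder capped at `max_model_len`. Below is a minimal, self-contained sketch of that pattern; the metric names and observed values are placeholders, not the real `vllm:*` series.

```python
# Labeled Prometheus metrics with a 1-2-5 histogram bucket ladder.
import prometheus_client


def build_1_2_5_buckets(max_value: int) -> list[int]:
    buckets, exponent = [], 0
    while True:
        for m in (1, 2, 5):
            value = m * 10**exponent
            if value > max_value:
                return buckets
            buckets.append(value)
        exponent += 1


labelnames = ["model_name", "engine"]
labelvalues = ["qwen3-8b", "0"]  # illustrative label values

gauge_running = prometheus_client.Gauge(
    name="demo:num_requests_running",
    documentation="Requests currently in model execution batches.",
    labelnames=labelnames).labels(*labelvalues)

hist_prompt_tokens = prometheus_client.Histogram(
    name="demo:request_prompt_tokens",
    documentation="Prompt tokens per request.",
    buckets=build_1_2_5_buckets(10000),  # [1, 2, 5, ..., 5000, 10000]
    labelnames=labelnames).labels(*labelvalues)

gauge_running.set(3)
hist_prompt_tokens.observe(1523)
print(prometheus_client.generate_latest().decode())
```
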
+ hits: int = 0 + + +@dataclass +class SchedulerStats: + """Stats associated with the scheduler.""" + + num_running_reqs: int = 0 + num_waiting_reqs: int = 0 + + gpu_cache_usage: float = 0.0 + + prefix_cache_stats: PrefixCacheStats = field( + default_factory=PrefixCacheStats) + + spec_decoding_stats: Optional[SpecDecodingStats] = None + + +@dataclass +class LoRAStats: + waiting_requests: set[str] = field(default_factory=set) + running_requests: set[str] = field(default_factory=set) + + +@dataclass +class RequestStateStats: + """Stats that need to be tracked across delta updates.""" + + num_generation_tokens: int = 0 + + # This is a engine frontend timestamp (wall-clock) + arrival_time: float = 0.0 + + # These are engine core timestamps (monotonic) + queued_ts: float = 0.0 + scheduled_ts: float = 0.0 + first_token_ts: float = 0.0 + last_token_ts: float = 0.0 + + +@dataclass +class FinishedRequestStats: + """Stats associated with a finished request.""" + + finish_reason: "FinishReason" + e2e_latency: float = 0.0 + num_prompt_tokens: int = 0 + num_generation_tokens: int = 0 + max_tokens_param: Optional[int] = None + queued_time: float = 0.0 + prefill_time: float = 0.0 + inference_time: float = 0.0 + decode_time: float = 0.0 + + +class IterationStats: + """Stats associated with a single set of EngineCoreOutputs.""" + + def __init__(self): + self.iteration_timestamp = time.time() + self.num_generation_tokens = 0 + self.num_prompt_tokens = 0 + self.num_preempted_reqs = 0 + self.finished_requests: list[FinishedRequestStats] = [] + self.max_num_generation_tokens_iter: list[int] = [] + self.n_params_iter: list[int] = [] + self.time_to_first_tokens_iter: list[float] = [] + self.time_per_output_tokens_iter: list[float] = [] + self.waiting_lora_adapters: dict[str, int] = {} + self.running_lora_adapters: dict[str, int] = {} + + def _time_since(self, start: float) -> float: + """Calculate an interval relative to this iteration's timestamp.""" + return self.iteration_timestamp - start + + def update_from_output(self, output: "EngineCoreOutput", + engine_core_timestamp: float, is_prefilling: bool, + prompt_len: int, req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats]): + num_new_generation_tokens = len(output.new_token_ids) + + self.num_generation_tokens += num_new_generation_tokens + if is_prefilling: + assert num_new_generation_tokens > 0 + self.num_prompt_tokens += prompt_len + + first_token_latency = self._time_since(req_stats.arrival_time) + self.time_to_first_tokens_iter.append(first_token_latency) + + req_stats.num_generation_tokens += num_new_generation_tokens + + # Process request-level engine core events + if output.events is not None: + self.update_from_events(output.request_id, output.events, + is_prefilling, req_stats, lora_stats) + + # Process the batch-level "new tokens" engine core event + if is_prefilling: + req_stats.first_token_ts = engine_core_timestamp + else: + tpot = engine_core_timestamp - req_stats.last_token_ts + self.time_per_output_tokens_iter.append(tpot) + + req_stats.last_token_ts = engine_core_timestamp + + def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], + is_prefilling: bool, req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats]): + # Avoid circular dependency + from vllm.v1.engine import EngineCoreEventType + for event in events: + if event.type == EngineCoreEventType.QUEUED: + req_stats.queued_ts = event.timestamp + if lora_stats is not None: + lora_stats.waiting_requests.add(req_id) + elif event.type == 
EngineCoreEventType.SCHEDULED: + if req_stats.scheduled_ts == 0.0: # ignore preemptions + req_stats.scheduled_ts = event.timestamp + LoRARequestStates.scheduled_request(lora_stats, req_id) + elif event.type == EngineCoreEventType.PREEMPTED: + self.num_preempted_reqs += 1 + LoRARequestStates.preempted_request(lora_stats, req_id) + + def update_from_finished_request(self, finish_reason: "FinishReason", + num_prompt_tokens: int, + max_tokens_param: Optional[int], + req_stats: RequestStateStats): + e2e_latency = self._time_since(req_stats.arrival_time) + + # Queued interval is from first QUEUED event to first SCHEDULED + queued_time = req_stats.scheduled_ts - req_stats.queued_ts + + # Prefill interval is from first SCHEDULED to first NEW_TOKEN + # Any preemptions during prefill is included in the interval + prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts + + # Decode interval is from first NEW_TOKEN to last NEW_TOKEN + # Any preemptions during decode are included + decode_time = req_stats.last_token_ts - req_stats.first_token_ts + + # Inference interval is from first SCHEDULED to last NEW_TOKEN + # Any preemptions during prefill or decode are included + inference_time = req_stats.last_token_ts - req_stats.scheduled_ts + + finished_req = \ + FinishedRequestStats(finish_reason=finish_reason, + e2e_latency=e2e_latency, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=req_stats.num_generation_tokens, + max_tokens_param=max_tokens_param, + queued_time=queued_time, + prefill_time=prefill_time, + inference_time=inference_time, + decode_time=decode_time) + self.finished_requests.append(finished_req) + + +class LoRARequestStates: + """Per-LoRA request state stats.""" + + def __init__(self): + self.lora_name_to_stats: dict[str, LoRAStats] = {} + + def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]: + if req_state.lora_name is None: + return None + if req_state.lora_name not in self.lora_name_to_stats: + self.lora_name_to_stats[req_state.lora_name] = LoRAStats() + return self.lora_name_to_stats[req_state.lora_name] + + def add_request(self, req_state: 'RequestState'): + if (lora_stats := self.get_stats(req_state)) is not None: + lora_stats.waiting_requests.add(req_state.request_id) + + def finish_request(self, req_state: 'RequestState'): + if req_state.lora_name is None: + return + lora_stats = self.lora_name_to_stats[req_state.lora_name] + lora_stats.running_requests.remove(req_state.request_id) + + def abort_request(self, req_state: 'RequestState'): + if req_state.lora_name is None: + return + lora_stats = self.lora_name_to_stats[req_state.lora_name] + lora_stats.waiting_requests.discard(req_state.request_id) + lora_stats.running_requests.discard(req_state.request_id) + + # Break the pattern for this lifecycle methods so we can + # call this from IterationStats.update_from_events() + @staticmethod + def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str): + if lora_stats is None: + return + lora_stats.waiting_requests.remove(request_id) + lora_stats.running_requests.add(request_id) + + @staticmethod + def preempted_request(lora_stats: Optional[LoRAStats], request_id: str): + if lora_stats is None: + return + lora_stats.running_requests.remove(request_id) + lora_stats.waiting_requests.add(request_id) + + def update_iteration_stats(self, + iteration_stats: Optional[IterationStats]): + if iteration_stats is None: + return + for lora_name, stats in self.lora_name_to_stats.items(): + if stats.waiting_requests: + 
iteration_stats.waiting_lora_adapters[lora_name] = \ + len(stats.waiting_requests) + if stats.running_requests: + iteration_stats.running_lora_adapters[lora_name] = \ + len(stats.running_requests) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py new file mode 100644 index 0000000..2732b93 --- /dev/null +++ b/vllm/v1/outputs.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import NamedTuple, Optional + +import torch + + +class LogprobsLists(NamedTuple): + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: list[list[int]] + # [num_reqs, max_num_logprobs + 1] + logprobs: list[list[float]] + # [num_reqs] + sampled_token_ranks: list[int] + + def slice(self, start: int, end: int): + return LogprobsLists( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.sampled_token_ranks[start:end], + ) + + +class LogprobsTensors(NamedTuple): + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: torch.Tensor + # [num_reqs, max_num_logprobs + 1] + logprobs: torch.Tensor + # [num_reqs] + selected_token_ranks: torch.Tensor + + def tolists(self): + return LogprobsLists( + self.logprob_token_ids.tolist(), + self.logprobs.tolist(), + self.selected_token_ranks.tolist(), + ) + + @staticmethod + def empty_cpu(num_positions: int, + num_tokens_per_position: int) -> "LogprobsTensors": + """Create empty LogprobsTensors on CPU.""" + + logprob_token_ids = torch.empty( + (num_positions, num_tokens_per_position), + dtype=torch.int32, + device="cpu") + logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32) + selected_token_ranks = torch.empty(num_positions, + dtype=torch.int32, + device="cpu") + return LogprobsTensors( + logprob_token_ids=logprob_token_ids, + logprobs=logprobs, + selected_token_ranks=selected_token_ranks, + ) + + +@dataclass +class SamplerOutput: + + # [num_reqs, max_num_generated_tokens] + # Different requests can have different number of generated tokens. + # All requests are padded to max_num_generated_tokens. + # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. + sampled_token_ids: torch.Tensor + logprobs_tensors: Optional[LogprobsTensors] + + +# ModelRunnerOutput is serialized and sent to the scheduler process. +# This is expensive for torch.Tensor so prefer to use list instead. +@dataclass +class ModelRunnerOutput: + + # [num_reqs] + req_ids: list[str] + # req_id -> index + req_id_to_index: dict[str, int] + + # num_reqs x num_generated_tokens + # num_generated_tokens is the number of tokens + # generated in the current step. It can be different for + # each request due to speculative/jump decoding. 
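
The per-request intervals that `update_from_finished_request` derives from the engine-core timestamps are easiest to see with concrete numbers. The timestamps below are invented monotonic-clock values, used only to show how the queued, prefill, decode, and inference intervals relate.

```python
# Worked example of the interval arithmetic in update_from_finished_request.
queued_ts = 100.00       # first QUEUED event
scheduled_ts = 100.25    # first SCHEDULED event (preemptions ignored)
first_token_ts = 100.95  # first NEW_TOKEN (end of prefill)
last_token_ts = 104.10   # last NEW_TOKEN

queued_time = scheduled_ts - queued_ts        # 0.25 s waiting in the queue
prefill_time = first_token_ts - scheduled_ts  # 0.70 s prefill (incl. preemptions)
decode_time = last_token_ts - first_token_ts  # 3.15 s decode
inference_time = last_token_ts - scheduled_ts # 3.85 s = prefill + decode

assert abs(inference_time - (prefill_time + decode_time)) < 1e-9
print(queued_time, prefill_time, decode_time, inference_time)
```
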
+ sampled_token_ids: list[list[int]] + + # num_reqs x num_spec_tokens + spec_token_ids: Optional[list[list[int]]] + + # [num_reqs, max_num_logprobs + 1] + # [num_reqs, max_num_logprobs + 1] + # [num_reqs] + logprobs: Optional[LogprobsLists] + + # req_id -> (token_ids, logprobs, ranks) + # [prompt_len, num_prompt_logprobs] + # [prompt_len, num_prompt_logprobs] + # [prompt_len] + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] + + +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, +) diff --git a/vllm/v1/request.py b/vllm/v1/request.py new file mode 100644 index 0000000..490fe4e --- /dev/null +++ b/vllm/v1/request.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +from typing import TYPE_CHECKING, Optional, Union + +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, + EngineCoreRequest, FinishReason) +from vllm.v1.structured_output.request import StructuredOutputRequest +from vllm.v1.utils import ConstantList + +if TYPE_CHECKING: + + from vllm.lora.request import LoRARequest + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.inputs import PlaceholderRange + + +class Request: + + def __init__( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: list[int], + multi_modal_inputs: Optional[list["MultiModalKwargs"]], + multi_modal_hashes: Optional[list[str]], + multi_modal_placeholders: Optional[list["PlaceholderRange"]], + sampling_params: SamplingParams, + eos_token_id: Optional[int], + arrival_time: float, + lora_request: Optional["LoRARequest"] = None, + structured_output_request: Optional["StructuredOutputRequest"] = None, + ) -> None: + self.request_id = request_id + self.sampling_params = sampling_params + # Because of LoRA, the eos token id can be different for each request. + self.eos_token_id = eos_token_id + self.lora_request = lora_request + self.structured_output_request = structured_output_request + + self.status = (RequestStatus.WAITING_FOR_FSM + if sampling_params.guided_decoding is not None else + RequestStatus.WAITING) + self.events: list[EngineCoreEvent] = [] + self.stop_reason: Union[int, str, None] = None + assert sampling_params.max_tokens is not None + self.max_tokens = sampling_params.max_tokens + + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.num_prompt_tokens = len(self.prompt_token_ids) + self._output_token_ids: list[int] = [] + self._all_token_ids: list[int] = self.prompt_token_ids.copy() + self.spec_token_ids: list[int] = [] + self.num_computed_tokens = 0 + + # Multi-modal related + self.mm_positions = multi_modal_placeholders or [] + self.mm_inputs = multi_modal_inputs or [] + self.mm_hashes: list[str] = multi_modal_hashes or [] + self.num_encoder_inputs = len(self.mm_inputs) + self.has_encoder_inputs = self.num_encoder_inputs > 0 + + # Sanity check + assert len(self.mm_inputs) == len(self.mm_positions) + if self.mm_hashes: + assert len(self.mm_inputs) == len(self.mm_hashes) + + # Read-only views + # Prevent directly appending to the these lists since + # they should also be updated simultaneously. 
+ self.output_token_ids = ConstantList(self._output_token_ids) + self.all_token_ids = ConstantList(self._all_token_ids) + + @classmethod + def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + return cls( + request_id=request.request_id, + prompt=request.prompt, + prompt_token_ids=request.prompt_token_ids, + multi_modal_inputs=request.mm_inputs, + multi_modal_hashes=request.mm_hashes, + multi_modal_placeholders=request.mm_placeholders, + sampling_params=request.sampling_params, + eos_token_id=request.eos_token_id, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + structured_output_request=StructuredOutputRequest( + sampling_params=request.sampling_params), + ) + + def append_output_token_ids( + self, + token_ids: Union[int, list[int]], + ) -> None: + if isinstance(token_ids, int): + self._output_token_ids.append(token_ids) + self._all_token_ids.append(token_ids) + else: + self._output_token_ids.extend(token_ids) + self._all_token_ids.extend(token_ids) + + @property + def num_tokens(self) -> int: + return len(self._all_token_ids) + + @property + def num_tokens_with_spec(self) -> int: + return len(self._all_token_ids) + len(self.spec_token_ids) + + @property + def num_output_tokens(self) -> int: + return len(self._output_token_ids) + + def is_finished(self) -> bool: + return RequestStatus.is_finished(self.status) + + def get_finished_reason(self) -> Union[FinishReason, None]: + return RequestStatus.get_finished_reason(self.status) + + def get_num_encoder_tokens(self, input_id: int) -> int: + assert input_id < len(self.mm_positions) + num_tokens = self.mm_positions[input_id]["length"] + return num_tokens + + @property + def use_structured_output(self) -> bool: + return self.sampling_params.guided_decoding is not None + + def record_event( + self, + event_type: EngineCoreEventType, + timestamp: Optional[float] = None, + ) -> None: + self.events.append(EngineCoreEvent.new_event(event_type, timestamp)) + + def take_events(self) -> Optional[list[EngineCoreEvent]]: + if not self.events: + return None + events, self.events = self.events, [] + return events + + +class RequestStatus(enum.IntEnum): + """Status of a request.""" + WAITING = enum.auto() + WAITING_FOR_FSM = enum.auto() + RUNNING = enum.auto() + PREEMPTED = enum.auto() + # Note: anything after PREEMPTED will be considered + # as a finished status. + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() + + @staticmethod + def is_finished(status: "RequestStatus") -> bool: + return status > RequestStatus.PREEMPTED + + @staticmethod + def get_finished_reason( + status: "RequestStatus") -> Union[FinishReason, None]: + return _FINISHED_REASON_MAP.get(status) + + +# Mapping of finished statuses to their finish reasons. +# NOTE: The ignored requests are the requests whose prompt lengths +# are longer than the model's length cap. Therefore, the stop +# reason should also be "length" as in OpenAI API. 
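
The `Request` class above keeps `_output_token_ids` and `_all_token_ids` in lockstep and only ever hands out read-only `ConstantList` views, so callers cannot append to one list without the other. A stripped-down sketch of that pattern follows; the `ConstantList` here is a stand-in for `vllm.v1.utils.ConstantList`, whose real implementation guards more mutating methods.

```python
# Read-only views over token lists that are always mutated together.
from collections.abc import Sequence


class ConstantList(Sequence):
    """Read-only view over a list that the owner mutates."""

    def __init__(self, backing: list[int]) -> None:
        self._backing = backing

    def __getitem__(self, idx):
        return self._backing[idx]

    def __len__(self) -> int:
        return len(self._backing)


prompt_token_ids = [101, 2023, 2003]          # illustrative token ids
_output_token_ids: list[int] = []
_all_token_ids: list[int] = prompt_token_ids.copy()

output_token_ids = ConstantList(_output_token_ids)
all_token_ids = ConstantList(_all_token_ids)


def append_output_token_ids(token_ids) -> None:
    # Both backing lists are updated together so the views never diverge.
    ids = [token_ids] if isinstance(token_ids, int) else list(token_ids)
    _output_token_ids.extend(ids)
    _all_token_ids.extend(ids)


append_output_token_ids(7592)
append_output_token_ids([1010, 2088])
print(len(output_token_ids), len(all_token_ids))  # 3 6
```
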
+_FINISHED_REASON_MAP = { + RequestStatus.FINISHED_STOPPED: FinishReason.STOP, + RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, + RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, + RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, +} diff --git a/vllm/v1/sample/__init__.py b/vllm/v1/sample/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py new file mode 100644 index 0000000..e97e123 --- /dev/null +++ b/vllm/v1/sample/metadata.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class SamplingMetadata: + + temperature: Optional[torch.Tensor] + all_greedy: bool + all_random: bool + + top_p: Optional[torch.Tensor] + top_k: Optional[torch.Tensor] + min_p: Optional[torch.Tensor] + + generators: dict[int, torch.Generator] + + # None means no logprobs, 0 means sampled token logprobs only + max_num_logprobs: Optional[int] + + no_penalties: bool + prompt_token_ids: Optional[torch.Tensor] + frequency_penalties: torch.Tensor + presence_penalties: torch.Tensor + repetition_penalties: torch.Tensor + + output_token_ids: list[list[int]] + + # req_index -> (min_tokens, stop_token_ids) + min_tokens: dict[int, tuple[int, set[int]]] + + logit_bias: list[Optional[dict[int, float]]] + + # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, + # vocab size). + allowed_token_ids_mask: Optional[torch.Tensor] + + # req_index -> bad_words_token_ids + bad_words_token_ids: dict[int, list[list[int]]] diff --git a/vllm/v1/sample/ops/__init__.py b/vllm/v1/sample/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py new file mode 100644 index 0000000..2984d4e --- /dev/null +++ b/vllm/v1/sample/ops/bad_words.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +_SMALLEST_LOGIT = float("-inf") + + +def _apply_bad_words_single_batch( + logits: torch.Tensor, + bad_words_token_ids: list[list[int]], + past_tokens_ids: list[int], +) -> None: + for bad_word_ids in bad_words_token_ids: + if len(bad_word_ids) > len(past_tokens_ids) + 1: + continue + + prefix_length = len(bad_word_ids) - 1 + last_token_id = bad_word_ids[-1] + if prefix_length > 0: + actual_prefix = past_tokens_ids[-prefix_length:] + else: + actual_prefix = [] + expected_prefix = bad_word_ids[:prefix_length] + + assert len(actual_prefix) == len(expected_prefix) + + if actual_prefix == expected_prefix: + logits[last_token_id] = _SMALLEST_LOGIT + + +def apply_bad_words( + logits: torch.Tensor, + bad_words_token_ids: dict[int, list[list[int]]], + past_tokens_ids: list[list[int]], +) -> None: + for i, bad_words_ids in bad_words_token_ids.items(): + _apply_bad_words_single_batch(logits[i], bad_words_ids, + past_tokens_ids[i]) diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py new file mode 100644 index 0000000..ed05e3f --- /dev/null +++ b/vllm/v1/sample/ops/penalties.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from vllm.model_executor.layers.utils import apply_penalties +from vllm.utils import is_pin_memory_available, make_tensor_with_pad + + +def apply_min_token_penalties( + logits: torch.Tensor, output_token_ids: list[list[int]], + min_tokens: dict[int, tuple[int, set[int]]]) -> None: + """ + Applies minimum token penalty by setting the logits of the stop tokens + to -inf. 
+ """ + min_tokens_logits_to_penalize: list[tuple[int, int]] = [] + for index, (min_token, stop_token_ids) in min_tokens.items(): + if len(output_token_ids[index]) < min_token: + for stop_token_id in stop_token_ids: + min_tokens_logits_to_penalize.append((index, stop_token_id)) + if min_tokens_logits_to_penalize: + logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") + + +def apply_all_penalties( + logits: torch.Tensor, + prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: list[list[int]], +) -> torch.Tensor: + """ + Applies presence, frequency and repetition penalties to the logits. + """ + _, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, + logits.device) + return apply_penalties(logits, prompt_token_ids, output_tokens_t, + presence_penalties, frequency_penalties, + repetition_penalties) + + +def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, + device: torch.device) -> torch.Tensor: + """ + Convert the different list data structures to tensors. + """ + output_tokens_tensor = make_tensor_with_pad( + output_token_ids, + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + pad=vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=is_pin_memory_available(), + ) + return output_tokens_tensor.to(device, non_blocking=True) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py new file mode 100644 index 0000000..f69623e --- /dev/null +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.nn as nn + +from vllm import envs +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +try: + import flashinfer.sampling + is_flashinfer_available = True +except ImportError: + is_flashinfer_available = False + + +class TopKTopPSampler(nn.Module): + """ + Module that performs optional top-k and top-p filtering followed by + weighted random sampling of logits. + + Implementations may update the logits tensor in-place. + """ + + def __init__(self): + super().__init__() + if current_platform.is_cuda(): + if is_flashinfer_available: + flashinfer_version = flashinfer.__version__ + if flashinfer_version >= "0.2.3": + # FIXME(DefTruth): Currently, we have errors when using + # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a + # workaround, we disable FlashInfer for top-p & top-k + # sampling by default while FlashInfer>=v0.2.3. + # The sampling API removes the success return value + # of all sampling API, which is not compatible with + # earlier design. + # https://github.com/flashinfer-ai/flashinfer/releases/ + # tag/v0.2.3 + logger.info( + "Currently, FlashInfer top-p & top-k sampling sampler " + "is disabled because FlashInfer>=v0.2.3 is not " + "backward compatible. Falling back to the PyTorch-" + "native implementation of top-p & top-k sampling.") + self.forward = self.forward_native + elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: + # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for + # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by + # default it is unused). 
For backward compatibility, we set + # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and + # interpret it differently in V0 and V1 samplers: In V0, + # None means False, while in V1, None means True. This is + # why we use the condition + # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. + logger.info("Using FlashInfer for top-p & top-k sampling.") + self.forward = self.forward_cuda + else: + logger.warning( + "FlashInfer is available, but it is not enabled. " + "Falling back to the PyTorch-native implementation of " + "top-p & top-k sampling. For the best performance, " + "please set VLLM_USE_FLASHINFER_SAMPLER=1.") + self.forward = self.forward_native + else: + logger.warning( + "FlashInfer is not available. Falling back to the PyTorch-" + "native implementation of top-p & top-k sampling. For the " + "best performance, please install FlashInfer.") + self.forward = self.forward_native + elif current_platform.is_tpu(): + if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: + logger.warning( + "TPU-specific optimization for top-k & top-p sampling are " + "disabled, falling back to PyTorch-native implementation " + "which could be very slow.") + self.forward = self.forward_native + else: + self.forward = self.forward_tpu + else: + self.forward = self.forward_native + + def forward_native( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + PyTorch-native implementation of top-k and top-p sampling. + + The logits tensor may be updated in-place. + """ + logits = apply_top_k_top_p(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + def forward_cuda( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + ) -> torch.Tensor: + """More optimized implementation for top-k and top-p sampling.""" + probs = logits.softmax(dim=-1, dtype=torch.float32) + if k is None and p is None: + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + return random_sample(probs, generators) + return flashinfer_sample(probs, k, p, generators) + + def forward_tpu( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + ) -> torch.Tensor: + logits = apply_top_k_top_p_tpu(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + +def apply_top_k_top_p_tpu( + logits: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, +) -> torch.Tensor: + """ + Apply top-k and top-p optimized for TPU. + + This algorithm avoids using torch.scatter which is extremely slow on TPU. + This is achieved by finding a "cut-off" element in the original logit, and + after thresholding the logit using this cut-off, the remaining elements + shall constitute the top-p set. + + Note: in the case of tie (i.e. multipple cut-off elements present in the + logit), all tie elements are included in the top-p set. In other words, + this function does not break ties. Instead, these tie tokens have equal + chance of being chosen during final sampling, so we can consider the tie + being broken then. 
+ """ + if k is not None: + logits = apply_top_k_only(logits, k) + + if p is not None: + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + If a top-p is used, this function will sort the logits tensor, + which can be slow for large batches. + + The logits tensor may be updated in-place. + """ + if p is None: + if k is None: + return logits + + # Avoid sorting vocab for top-k only case. + return apply_top_k_only(logits, k) + + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if k is not None: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits + + +def apply_top_k_only( + logits: torch.Tensor, + k: torch.Tensor, +) -> torch.Tensor: + """ + Apply top-k mask to the logits. + + This implementation doesn't involve sorting the entire vocab. + + The logits tensor may be updated in-place. + """ + no_top_k_mask = k == logits.shape[1] + # Set non-top-k rows to 1 so that we can gather. + k = k.masked_fill(no_top_k_mask, 1) + max_top_k = k.max() + # topk.values tensor has shape [batch_size, max_top_k]. + # Convert top k to 0-based index in range [0, max_top_k). + k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1) + top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) + # Handle non-topk rows. + top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) + logits.masked_fill_(logits < top_k_mask, -float("inf")) + return logits + + +def random_sample( + probs: torch.Tensor, + generators: dict[int, torch.Generator], +) -> torch.Tensor: + """Randomly sample from the probabilities. + + We use this function instead of torch.multinomial because torch.multinomial + causes CPU-GPU synchronization. + """ + q = torch.empty_like(probs) + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. + if len(generators) != probs.shape[0]: + q.exponential_() + if generators: + # TODO(woosuk): This can be slow because we handle each request + # one by one. Optimize this. 
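
The `random_sample` helper here avoids `torch.multinomial` by dividing the probabilities by i.i.d. Exp(1) noise and taking an argmax, which keeps the whole operation on the GPU with no CPU-GPU sync. The snippet below is a quick numerical check that this exponential-race trick matches `torch.multinomial` in distribution; the probabilities and sample count are arbitrary.

```python
# Empirical check: argmax(probs / q) with q ~ Exp(1) reproduces the
# categorical distribution sampled by torch.multinomial.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000

# Exponential-race sampler: one argmax per draw, no multinomial call.
q = torch.empty(n, probs.numel()).exponential_()
race_samples = (probs / q).argmax(dim=-1)

# Reference sampler.
ref_samples = torch.multinomial(probs, n, replacement=True)

race_freq = torch.bincount(race_samples, minlength=4) / n
ref_freq = torch.bincount(ref_samples, minlength=4) / n
print(race_freq)  # ~[0.10, 0.20, 0.30, 0.40]
print(ref_freq)   # ~[0.10, 0.20, 0.30, 0.40]
```
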
+ for i, generator in generators.items(): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1) + + +def flashinfer_sample( + probs: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + generators: dict[int, torch.Generator], +) -> torch.Tensor: + """Sample from the probabilities using FlashInfer. + + Statistically, this function is equivalent to the `random_sample` function. + However, this function is faster because it avoids sorting the logits tensor + via rejection sampling. + + NOTE: The outputs of this function do not necessarily match the outputs of + the `random_sample` function. It only guarantees that the outputs are + statistically equivalent. + + NOTE: This function includes CPU-GPU synchronization, while `random_sample` + does not. Call this function at the end of the forward pass to minimize + the synchronization overhead. + """ + assert not (k is None and p is None) + max_top_k_round = 32 + batch_size = probs.shape[0] + uniform_samples = torch.empty((max_top_k_round, batch_size), + device=probs.device) + if len(generators) != batch_size: + uniform_samples.uniform_() + if generators: + for i, generator in generators.items(): + uniform_samples[:, i].uniform_(generator=generator) + + if k is None: + # Top-p only. + next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( + probs, uniform_samples, p, deterministic=True) + elif p is None: + # Top-k only. + next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( + probs, uniform_samples, k, deterministic=True) + else: + # Both top-k and top-p. + next_token_ids, success = ( + flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, uniform_samples, k, p, deterministic=True)) + + # NOTE: CPU-GPU synchronization happens here. + if not success.all(): + if k is not None: + probs = flashinfer.sampling.top_k_renorm_prob(probs, k) + if p is not None: + probs = flashinfer.sampling.top_p_renorm_prob(probs, p) + next_token_ids = flashinfer.sampling.sampling_from_probs( + probs, uniform_samples[0], deterministic=True) + return next_token_ids.view(-1) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py new file mode 100644 index 0000000..3cf7fde --- /dev/null +++ b/vllm/v1/sample/rejection_sampler.py @@ -0,0 +1,631 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from vllm.logger import init_logger +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +logger = init_logger(__name__) + +PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 +GREEDY_TEMPERATURE: tl.constexpr = -1 +# Maximum number of speculative draft tokens allowed per request in a single +# step. This value is chosen to be large enough to handle typical use cases. +MAX_SPEC_LEN = 32 + + +class RejectionSampler(nn.Module): + """ + The implementation strictly follows the algorithm described in + https://arxiv.org/abs/2211.17192. + However, we want to clarify the terminology used in the implementation: + accepted tokens: tokens that are accepted based on the relationship + between the "raw" draft and target probabilities. + recovered tokens: tokens that are sampled based on the adjusted probability + distribution, which is derived from both the draft and target + probabilities. 
+ bonus tokens: + If all proposed tokens are accepted, the bonus token is added to the + end of the sequence. The bonus token is only sampled from the target + probabilities. We pass in the bonus tokens instead of sampling them + in the rejection sampler to allow for more flexibility in the + sampling process. For example, we can use top_p, top_k sampling for + bonus tokens, while spec decode does not support these sampling + strategies. + output tokens: + Tokens are finally generated with the rejection sampler. + output tokens = accepted tokens + recovered tokens + bonus tokens + """ + + def forward( + self, + metadata: SpecDecodeMetadata, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_logits: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + ''' + Args: + metadata: + Metadata for spec decoding. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + target_logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. + NOTE: `target_logits` can be updated in place to save memory. + bonus_token_ids_tensor (torch.Tensor): + A tensor containing bonus tokens. Shape is [batch_size, 1]. + Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling + process such as top_p, top_k sampling. + sampling_metadata (SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + ''' + assert metadata.max_spec_len <= MAX_SPEC_LEN + # [num_tokens, vocab_size] + # NOTE(woosuk): `target_logits` can be updated in place inside the + # `compute_probs` function. + target_probs = compute_probs( + target_logits, + metadata.cu_num_draft_tokens, + sampling_metadata, + ) + + output_token_ids = rejection_sample( + metadata.draft_token_ids, + metadata.num_draft_tokens, + metadata.max_spec_len, + metadata.cu_num_draft_tokens, + draft_probs, + target_probs, + bonus_token_ids, + sampling_metadata, + ) + return output_token_ids + + @staticmethod + def parse_output( + output_token_ids: torch.Tensor, + vocab_size: int, + ) -> list[list[int]]: + """Parse the output of the rejection sampler. + + Args: + output_token_ids: The sampled token IDs in shape + [batch_size, max_spec_len + 1]. The rejected tokens are + replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler + and will be filtered out in this function. + vocab_size: The size of the vocabulary. + + Returns: + A list of lists of token IDs. + """ + output_token_ids_np = output_token_ids.cpu().numpy() + # Create mask for valid tokens. 
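
The accept / recover / bonus rule described in the `RejectionSampler` docstring above can be written for a single request in a few lines of eager PyTorch, which makes the batched Triton kernels easier to follow. This is only a readability aid based on the algorithm in arXiv:2211.17192, not vLLM's kernel path; the vocabulary size, probabilities, and token ids below are made up.

```python
# Single-request reference implementation of speculative-decoding
# rejection sampling: accept, recover on first rejection, or append bonus.
import torch


def rejection_sample_one(draft_token_ids: list[int],
                         draft_probs: torch.Tensor,   # [num_draft, vocab]
                         target_probs: torch.Tensor,  # [num_draft, vocab]
                         bonus_token_id: int) -> list[int]:
    output: list[int] = []
    for i, d in enumerate(draft_token_ids):
        p_target = target_probs[i, d]
        p_draft = draft_probs[i, d]
        # Accept the draft token with probability min(1, p_target / p_draft).
        if torch.rand(()) < torch.clamp(p_target / p_draft, max=1.0):
            output.append(d)
            continue
        # Rejected: sample a "recovered" token from the adjusted
        # distribution max(target - draft, 0), renormalized, then stop.
        adjusted = torch.clamp(target_probs[i] - draft_probs[i], min=0)
        adjusted /= adjusted.sum()
        output.append(int(torch.multinomial(adjusted, 1)))
        return output
    # All draft tokens accepted: append the separately sampled bonus token.
    output.append(bonus_token_id)
    return output


vocab, num_draft = 8, 3
draft_probs = torch.softmax(torch.randn(num_draft, vocab), dim=-1)
target_probs = torch.softmax(torch.randn(num_draft, vocab), dim=-1)
print(rejection_sample_one([1, 4, 2], draft_probs, target_probs,
                           bonus_token_id=5))
```
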
+ valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & + (output_token_ids_np < vocab_size)) + outputs = [ + row[valid_mask[i]].tolist() + for i, row in enumerate(output_token_ids_np) + ] + return outputs + + +def rejection_sample( + # [num_tokens] + draft_token_ids: torch.Tensor, + # [batch_size] + num_draft_tokens: list[int], + max_spec_len: int, + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + assert draft_token_ids.ndim == 1 + assert draft_probs is None or draft_probs.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + assert target_probs.ndim == 2 + + batch_size = len(num_draft_tokens) + num_tokens = draft_token_ids.shape[0] + vocab_size = target_probs.shape[-1] + device = target_probs.device + assert draft_token_ids.is_contiguous() + assert draft_probs is None or draft_probs.is_contiguous() + assert target_probs.is_contiguous() + assert bonus_token_ids.is_contiguous() + assert target_probs.shape == (num_tokens, vocab_size) + + # Create output buffer. + output_token_ids = torch.empty( + (batch_size, max_spec_len + 1), + dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. + device=device, + ) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + + if sampling_metadata.all_greedy: + is_greedy = None + else: + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE + if not sampling_metadata.all_random: + # Rejection sampling for greedy sampling requests. + target_argmax = target_probs.argmax(dim=-1) + rejection_greedy_sample_kernel[(batch_size, )]( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + num_warps=1, + ) + if sampling_metadata.all_greedy: + return output_token_ids + + # Generate uniform probabilities for rejection sampling. + # [num_tokens] + uniform_probs = generate_uniform_probs( + num_tokens, + num_draft_tokens, + sampling_metadata.generators, + device, + ) + + # Sample recovered tokens for each position. + # [num_tokens] + recovered_token_ids = sample_recovered_tokens( + max_spec_len, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device, + ) + + # Rejection sampling for random sampling requests. + rejection_random_sample_kernel[(batch_size, )]( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + num_warps=1, + ) + return output_token_ids + + +def compute_probs( + logits: torch.Tensor, # [num_tokens, vocab_size] + cu_num_draft_tokens: torch.Tensor, # [batch_size] + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + """Compute probability distribution from logits based on sampling metadata. + + This function applies temperature scaling to the logits and converts + them to probabilities using softmax. For greedy decoding, it returns + the original logits. + + Args: + logits: Input logits tensor to be converted to probabilities. + cu_num_draft_tokens: Cumulative number of draft tokens. + sampling_metadata: Metadata containing sampling parameters such as + temperature and whether greedy sampling is used. 
+ + Returns: + torch.Tensor: Probability distribution (softmax of scaled logits) + if non-greedy sampling is used, otherwise returns the + original logits. + """ + assert logits.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + if sampling_metadata.all_greedy: + return logits + + num_tokens = logits.shape[0] + temperature = expand_batch_to_tokens( + sampling_metadata.temperature, + cu_num_draft_tokens, + num_tokens, + replace_from=GREEDY_TEMPERATURE, + replace_to=1, + ) + # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor. + logits.div_(temperature.unsqueeze(-1)) + + # Get expanded top_k and top_p tensors. + top_k = None + if sampling_metadata.top_k is not None: + top_k = expand_batch_to_tokens( + sampling_metadata.top_k, + cu_num_draft_tokens, + num_tokens, + ) + top_p = None + if sampling_metadata.top_p is not None: + top_p = expand_batch_to_tokens( + sampling_metadata.top_p, + cu_num_draft_tokens, + num_tokens, + ) + + # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, + # which is slow for large vocab sizes. This may cause performance issues. + logits = apply_top_k_top_p(logits, top_k, top_p) + output_prob = logits.softmax(dim=-1, dtype=torch.float32) + return output_prob + + +def expand_batch_to_tokens( + x: torch.Tensor, # [batch_size] + cu_num_tokens: torch.Tensor, # [batch_size] + num_tokens: int, + replace_from: int = 0, + replace_to: int = 0, +) -> torch.Tensor: + """Expand [batch_size] tensor to [num_tokens] tensor based on the number of + tokens per batch in cu_num_tokens. + + For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then + num_tokens = 6, and expanded_x = [a, a, b, b, b, c]. + + Args: + x: [batch_size] tensor to expand. + cu_num_tokens: [batch_size] tensor containing the cumulative number of + tokens per batch. Each element represents the total number of + tokens up to and including that batch. + num_tokens: Total number of tokens. + replace_from: int = 0 + Value to be replaced if it is found in x. + replace_to: int = 0 + Value to replace with when replace_from is found. + Returns: + expanded_x: [num_tokens] tensor. + """ + batch_size = x.shape[0] + assert cu_num_tokens.shape[0] == batch_size + expanded_x = x.new_empty(num_tokens) + expand_kernel[(batch_size, )]( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. + num_warps=1, + ) + return expanded_x + + +def generate_uniform_probs( + num_tokens: int, + num_draft_tokens: list[int], + generators: dict[int, torch.Generator], + device: torch.device, +) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + if available. + + This method creates a tensor of shape `(num_tokens, )` filled + with uniform random values in the range [0, 1). If `generators` is provided, + the requests with their own seeds will use the provided `torch.Generator` + for reproducibility. The samples for the other requests will be generated + without a seed. + + Args: + num_tokens : int + Total number of tokens. + num_draft_tokens : List[List[int]] + Number of draft tokens per request. + generators : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. + device : torch.device + The device on which to allocate the tensor. + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(num_tokens, )` containing uniform + random values in the range [0, 1). 
+ """ + uniform_probs = torch.rand( + (num_tokens, ), + dtype=torch.float32, + device=device, + ) + start_idx = 0 + for req_idx, n in enumerate(num_draft_tokens): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. + if n == 0: + continue + end_idx = start_idx + n + generator = generators.get(req_idx) + if generator is not None: + uniform_probs[start_idx:end_idx].uniform_(generator=generator) + start_idx = end_idx + return uniform_probs + + +def sample_recovered_tokens( + max_spec_len: int, + num_draft_tokens: list[int], + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens] + draft_token_ids: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: + # NOTE(woosuk): Create only one distribution for each request. + batch_size = len(num_draft_tokens) + vocab_size = target_probs.shape[-1] + q = torch.empty( + (batch_size, vocab_size), + dtype=torch.float32, + device=device, + ) + q.exponential_() + for i, generator in sampling_metadata.generators.items(): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. + if num_draft_tokens[i] > 0: + q[i].exponential_(generator=generator) + + recovered_token_ids = torch.empty_like(draft_token_ids) + sample_recovered_tokens_kernel[(batch_size, max_spec_len)]( + recovered_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + q, + vocab_size, + triton.next_power_of_2(vocab_size), + IS_NGRAM=draft_probs is None, + ) + return recovered_token_ids + + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. +@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_greedy_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + target_argmax_ptr, # [num_tokens] + bonus_token_ids_ptr, # [batch_size] + is_greedy_ptr, # [batch_size] or None + max_spec_len, +): + req_idx = tl.program_id(0) + # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, + # re-compilation may happen during runtime when is_greedy_ptr is None. + if is_greedy_ptr is None: + is_greedy = True + else: + is_greedy = tl.load(is_greedy_ptr + req_idx) + if not is_greedy: + # Early exit for non-greedy sampling requests. + return + + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + target_argmax_id) + if draft_token_id != target_argmax_id: + # Reject. + rejected = True + + if not rejected: + # If all tokens are accepted, append the bonus token. + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, bonus_token_id) + + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. 
+@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_random_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + IS_NGRAM: tl.constexpr, +): + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + if is_greedy: + # Early exit for greedy sampling requests. + return + + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + if IS_NGRAM: + draft_prob = 1 + else: + draft_prob = tl.load(draft_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) + # NOTE(woosuk): While the draft probability should never be 0, + # we check it to avoid NaNs. If it happens to be 0, we reject. + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + # Accept. + token_id = draft_token_id + else: + # Reject. Use recovered token. + rejected = True + token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + token_id) + + if not rejected: + # If all tokens are accepted, append the bonus token. + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, bonus_token_id) + + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. +@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +def expand_kernel( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + MAX_NUM_TOKENS: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: # noqa: SIM108 + start_idx = 0 + else: + start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_tokens_ptr + req_idx) + num_tokens = end_idx - start_idx + + src_val = tl.load(input_ptr + req_idx) + src_val = tl.where(src_val == replace_from, replace_to, src_val) + offset = tl.arange(0, MAX_NUM_TOKENS) + tl.store(output_ptr + start_idx + offset, + src_val, + mask=offset < num_tokens) + + +@triton.jit +def sample_recovered_tokens_kernel( + output_token_ids_ptr, # [num_tokens] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + q_ptr, # [batch_size, vocab_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, + IS_NGRAM: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + # Early exit for out-of-range positions. 
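+    # The kernel is launched on a (batch_size, max_spec_len) grid: axis 0
+    # selects the request and axis 1 the draft position, so programs beyond
+    # this request's actual number of draft tokens return immediately.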
+ pos = tl.program_id(1) + if pos >= num_draft_tokens: + return + + vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) + if IS_NGRAM: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + draft_token_id) + # Temporarily zero out the probability of the draft token. + # This is essentially the same as target_prob - draft_prob, except that + # n-gram does not have draft_prob. We regard it as 1. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + 0) + prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + else: + draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + prob = tl.maximum(target_prob - draft_prob, 0) + # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. + + q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf")) + recovered_id = tl.argmax(prob / q, axis=-1) + tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) + + if IS_NGRAM: + # Restore the original probability. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + orig_prob) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py new file mode 100644 index 0000000..004f984 --- /dev/null +++ b/vllm/v1/sample/sampler.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A layer that samples the next tokens from the model's outputs.""" + +import torch +import torch.nn as nn + +from vllm.v1.outputs import LogprobsTensors, SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.bad_words import apply_bad_words +from vllm.v1.sample.ops.penalties import (apply_all_penalties, + apply_min_token_penalties) +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler + +_SAMPLING_EPS = 1e-5 + + +class Sampler(nn.Module): + + def __init__(self): + super().__init__() + self.topk_topp_sampler = TopKTopPSampler() + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + # TODO(rob): provide option for logprobs post sampling. + # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None: + raw_logprobs = self.compute_logprobs(logits) + + # Use float32 for the logits. + logits = logits.to(torch.float32) + # Apply allowed token ids. + logits = self.apply_allowed_token_ids(logits, sampling_metadata) + # Apply bad words exclusion. + logits = self.apply_bad_words(logits, sampling_metadata) + # Apply logits bias. + logits = self.apply_logits_bias(logits, sampling_metadata) + # Apply penalties (e.g., min_tokens, freq_penalties). + logits = self.apply_penalties(logits, sampling_metadata) + # Sample the next token. 
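+        # NOTE: `sample` may modify `logits` in place (temperature, min_p,
+        # top-k/top-p), which is why `raw_logprobs` is computed from the
+        # untouched logits above when logprobs are requested.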
+ sampled = self.sample(logits, sampling_metadata) + # Convert sampled token ids to int64 (long) type to ensure compatibility + # with subsequent operations that may use these values as indices. + # This conversion is necessary because FlashInfer sampling operations + # return int32 (while PyTorch argmax and topk return int64). + sampled = sampled.long() + + # Gather the logprobs of the topk and sampled token (if requested). + # Get logprobs and rank tensors (if requested) + logprobs_tensors = None if num_logprobs is None else \ + self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + + # Use int32 to reduce the tensor size. + sampled = sampled.to(torch.int32) + + # These are GPU tensors. + sampler_output = SamplerOutput( + # The sampled tokens are expanded to 2D tensor with shape + # [num_requests, 1], where each row represents one generated + # token per request. + sampled_token_ids=sampled.unsqueeze(-1), + logprobs_tensors=logprobs_tensors, + ) + return sampler_output + + def apply_temperature( + self, + logits: torch.Tensor, + temp: torch.Tensor, + ) -> torch.Tensor: + # Use in-place division to avoid creating a new tensor. + return logits.div_(temp.unsqueeze(dim=1)) + + def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + return logits.argmax(dim=-1).view(-1) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + """Sample logits based on sampling metadata. + + The various logits processing functions called in this method + may update the logits tensor in-place. + """ + + assert not (sampling_metadata.all_greedy + and sampling_metadata.all_random) + if sampling_metadata.all_random: + greedy_sampled = None + else: + greedy_sampled = self.greedy_sample(logits) + if sampling_metadata.all_greedy: + return greedy_sampled + + assert sampling_metadata.temperature is not None + + # Apply temperature. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + + # Apply min_p. + if sampling_metadata.min_p is not None: + logits = self.apply_min_p(logits, sampling_metadata.min_p) + + # Apply top_k and/or top_p. + random_sampled = self.topk_topp_sampler( + logits, + sampling_metadata.generators, + sampling_metadata.top_k, + sampling_metadata.top_p, + ) + + if greedy_sampled is None: + return random_sampled + + sampled = torch.where( + sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, + random_sampled, + out=greedy_sampled, # Reuse tensor + ) + return sampled + + def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return logits.log_softmax(dim=-1, dtype=torch.float32) + + def gather_logprobs( + self, + logprobs: torch.Tensor, + num_logprobs: int, + token_ids: torch.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == torch.int64 + # Find the topK values. + topk_logprobs, topk_indices = torch.topk(logprobs, + num_logprobs, + dim=-1) + + # Get with the logprob of the prompt or sampled token. 
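+        # The rank computed below counts how many vocab entries have a logprob
+        # greater than or equal to the token's own, so the most likely token
+        # gets rank 1.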
+ token_ids = token_ids.unsqueeze(-1) + token_logprobs = logprobs.gather(-1, token_ids) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + # Concatenate together with the topk. + indices = torch.cat((token_ids, topk_indices), dim=1) + logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1) + + # Use int32 to reduce the tensor size. + indices = indices.to(torch.int32) + + return LogprobsTensors(indices, logprobs, token_ranks) + + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + if sampling_metadata.min_tokens: + apply_min_token_penalties(logits, + sampling_metadata.output_token_ids, + sampling_metadata.min_tokens) + if not sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + logits = apply_all_penalties( + logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids, + ) + return logits + + def apply_min_p( + self, + logits: torch.Tensor, + min_p: torch.Tensor, + ) -> torch.Tensor: + """ + Filters logits using adaptive probability thresholding. + """ + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, + dim=-1, + keepdim=True) + # Reshape min_p for broadcasting + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities + # Identify valid tokens using threshold comparison + valid_token_mask = probability_values >= adjusted_min_p + # Apply mask using boolean indexing + logits[~valid_token_mask] = -float('inf') + return logits + + def apply_logits_bias( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + # TODO(houseroad): this implementation is extremely inefficient. + # One idea is implement this as a PyTorch C++ op, and we may + # even optimize the logit_bias layout. 
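+        # For now, fall back to a per-request Python loop over the sparse
+        # {token_id: bias} dicts; requests without bias entries are skipped.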
+ for i, logit_bias in enumerate(sampling_metadata.logit_bias): + if logit_bias: + for token_id, bias in logit_bias.items(): + logits[i, token_id] += bias + return logits + + def apply_allowed_token_ids( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + if sampling_metadata.allowed_token_ids_mask is not None: + logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, + float("-inf")) + return logits + + def apply_bad_words( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + if sampling_metadata.bad_words_token_ids: + apply_bad_words( + logits, + sampling_metadata.bad_words_token_ids, + sampling_metadata.output_token_ids, + ) + return logits diff --git a/vllm/v1/sample/tpu/__init__.py b/vllm/v1/sample/tpu/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py new file mode 100644 index 0000000..89d3ddf --- /dev/null +++ b/vllm/v1/sample/tpu/metadata.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Optional + +import torch +import torch_xla.core.xla_model as xm + +from vllm.v1.worker.gpu_input_batch import InputBatch + +DEFAULT_SAMPLING_PARAMS = dict( + temperature=-1.0, + min_p=0.0, + # strictly disabled for now + # top_k=-1, + # top_p=0.0, + # frequency_penalties=0.0, + # presence_penalties=0.0, + # repetition_penalties=0.0, +) + + +@dataclass +class TPUSupportedSamplingMetadata: + # This class exposes a more xla-friendly interface than SamplingMetadata + # on TPU, in particular all arguments should be traceable and no optionals + # are allowed, to avoid graph recompilation on Nones. + temperature: torch.Tensor + + min_p: torch.Tensor + # Still too slow on forward_native! + top_k: torch.Tensor = None + top_p: torch.Tensor = None + + # Greedy sampling flag for compiling single xla graph. + all_greedy: torch.Tensor = None + + # Generator not supported by xla + generators: dict[int, + torch.Generator] = field(default_factory=lambda: dict()) + + # unsupported, you need to return an extra tensor of static size BxV + max_num_logprobs = None + + # TODO No penalties for now + no_penalties: bool = True + prompt_token_ids = None + frequency_penalties = None + presence_penalties = None + repetition_penalties = None + # should use tensor + output_token_ids: list[list[int]] = field(default_factory=lambda: list()) + + min_tokens = None # impl is not vectorized + + logit_bias: list[Optional[dict[int, float]]] = field( + default_factory=lambda: list()) + + allowed_token_ids_mask = None + bad_words_token_ids = None + indices_do_sample: torch.Tensor = None + + @classmethod + def from_input_batch( + cls, input_batch: InputBatch, + indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata": + """ + Copy sampling tensors slices from `input_batch` to on device tensors. + + `InputBatch._make_sampling_metadata` causes recompilation on XLA as it + slices dynamic shapes on device tensors. This impl moves the dynamic + ops to CPU and produces tensors of fixed `padded_num_reqs` size. It + also reuses the on-device persistent tensors managed in `input_batch` + to reduce waste. + + `indices_do_sample` contains the indices to be fed to the Sampler, + normally one per request, here padded to the closest pre-compiled shape + We expect sampling params tensors to be padded to the same fixed shape. + + Eg. 
3 requests, tensors padded to 4 + temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0] + sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0] + """ + num_reqs = input_batch.num_reqs + padded_num_reqs = len(indices_do_sample) + + def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, + fill_val) -> torch.Tensor: + # Copy slice from CPU to corresponding TPU pre-allocated tensor. + # Pad value is the default one. + cpu_tensor[num_reqs:padded_num_reqs] = fill_val + # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs` + tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs] + + # NOTE NickLucche The sync CPU-TPU graph we produce here must be + # consistent. We can't have flags to skip copies or we'll end up + # recompiling. + copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature, + DEFAULT_SAMPLING_PARAMS["temperature"]) + # TODO Temporarily disabled until sampling options are enabled + # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p) + # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k) + copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p, + DEFAULT_SAMPLING_PARAMS["min_p"]) + + xm.mark_step() + xm.wait_device_ops() + + # Slice persistent device tensors to a fixed pre-compiled padded shape. + return cls( + temperature=input_batch.temperature[:padded_num_reqs], + # Scalar tensor for xla-friendly tracing. + all_greedy=torch.tensor(input_batch.all_greedy, + dtype=torch.bool, + device=input_batch.device), + # TODO enable more and avoid returning None values + top_p=None, # input_batch.top_p[:padded_num_reqs], + top_k=None, # input_batch.top_k[:padded_num_reqs], + min_p=input_batch.min_p[:padded_num_reqs], + generators=input_batch.generators, + indices_do_sample=indices_do_sample) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py new file mode 100644 index 0000000..33526c0 --- /dev/null +++ b/vllm/v1/sample/tpu/sampler.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Sampler layer implementing TPU supported operations.""" + +import torch +import torch.nn as nn + +from vllm.v1.outputs import LogprobsTensors, SamplerOutput +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler +from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata + +_SAMPLING_EPS = 1e-5 + + +class Sampler(nn.Module): + + def __init__(self): + super().__init__() + self.topk_topp_sampler = TopKTopPSampler() + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata, + ) -> SamplerOutput: + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + + # Use float32 for the logits. + logits = logits.to(torch.float32) + # Sample the next token. + sampled = self.sample(logits, sampling_metadata) + + # Use int32 to reduce the tensor size. + sampled = sampled.to(torch.int32) + + # These are GPU tensors. + sampler_output = SamplerOutput( + # The sampled tokens are expanded to 2D tensor with shape + # [num_requests, 1], where each row represents one generated + # token per request. + sampled_token_ids=sampled.unsqueeze(-1), + logprobs_tensors=None, + ) + return sampler_output + + def apply_temperature( + self, + logits: torch.Tensor, + temp: torch.Tensor, + ) -> torch.Tensor: + # Use in-place division to avoid creating a new tensor. 
+ return logits.div_(temp.unsqueeze(dim=1)) + + def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + return logits.argmax(dim=-1).view(-1) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata, + ) -> torch.Tensor: + greedy_sampled = self.greedy_sample(logits) + + assert sampling_metadata.temperature is not None + + # Apply temperature. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + + # Apply min_p. + if sampling_metadata.min_p is not None: + logits = self.apply_min_p(logits, sampling_metadata.min_p) + + # Apply top_k and/or top_p. + random_sampled = self.topk_topp_sampler( + logits, + sampling_metadata.generators, + sampling_metadata.top_k, + sampling_metadata.top_p, + ) + + sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, random_sampled) + return sampled + + def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return logits.log_softmax(dim=-1, dtype=torch.float32) + + def gather_logprobs( + self, + logprobs: torch.Tensor, + num_logprobs: int, + token_ids: torch.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + + Args: + logits: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + # Find the topK values. + topk_logprobs, topk_indices = torch.topk(logprobs, + num_logprobs, + dim=-1) + + # Get with the logprob of the prompt or sampled token. + token_ids = token_ids.unsqueeze(-1) + token_logprobs = logprobs.gather(-1, token_ids) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + # Concatenate together with the topk. + indices = torch.cat((token_ids, topk_indices), dim=1) + logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1) + + # Use int32 to reduce the tensor size. + indices = indices.to(torch.int32) + + return LogprobsTensors(indices, logprobs, token_ranks) + + def apply_min_p( + self, + logits: torch.Tensor, + min_p: torch.Tensor, + ) -> torch.Tensor: + """ + Filters logits using adaptive probability thresholding. 
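+        Tokens whose probability falls below min_p * (row max probability)
+        are masked to -inf. masked_fill_ is used rather than boolean
+        assignment to keep the op XLA-friendly.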
+ """ + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, + dim=-1, + keepdim=True) + # Reshape min_p for broadcasting + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities + # Identify valid tokens using threshold comparison + valid_token_mask = probability_values >= adjusted_min_p + # Apply mask using boolean indexing (xla friendly) + logits.masked_fill_(~valid_token_mask, -float("inf")) + return logits diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py new file mode 100644 index 0000000..146d7d7 --- /dev/null +++ b/vllm/v1/serial_utils.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pickle +from types import FunctionType +from typing import Any, Optional + +import cloudpickle +import torch +from msgspec import msgpack + +CUSTOM_TYPE_TENSOR = 1 +CUSTOM_TYPE_PICKLE = 2 +CUSTOM_TYPE_CLOUDPICKLE = 3 + + +class MsgpackEncoder: + """Encoder with custom torch tensor serialization.""" + + def __init__(self): + self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook) + + def encode(self, obj: Any) -> bytes: + return self.encoder.encode(obj) + + def encode_into(self, obj: Any, buf: bytearray) -> None: + self.encoder.encode_into(obj, buf) + + +class MsgpackDecoder: + """Decoder with custom torch tensor serialization.""" + + def __init__(self, t: Optional[Any] = None): + args = () if t is None else (t, ) + self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook) + + def decode(self, obj: Any): + return self.decoder.decode(obj) + + +def custom_enc_hook(obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + # NOTE(rob): it is fastest to use numpy + pickle + # when serializing torch tensors. 
+ # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 + return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) + + if isinstance(obj, FunctionType): + return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) + + return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) + + +def custom_ext_hook(code: int, data: memoryview) -> Any: + if code == CUSTOM_TYPE_TENSOR: + return torch.from_numpy(pickle.loads(data)) + if code == CUSTOM_TYPE_PICKLE: + return pickle.loads(data) + if code == CUSTOM_TYPE_CLOUDPICKLE: + return cloudpickle.loads(data) + + raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/vllm/v1/spec_decode/__init__.py b/vllm/v1/spec_decode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py new file mode 100644 index 0000000..3aaaf34 --- /dev/null +++ b/vllm/v1/spec_decode/eagle.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.sample.metadata import SamplingMetadata + + +class EagleProposer: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens) + self.block_size = vllm_config.cache_config.block_size + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, + device=device) + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + + input_ids = torch.empty_like(target_token_ids) + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + input_ids[:-1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + input_ids[last_token_indices] = next_token_ids + + seq_lens = target_positions[last_token_indices] + 1 + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + max_seq_len = seq_lens.max().item() + max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_tokens, + max_query_len=max_num_tokens, + query_start_loc=cu_num_tokens, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=target_slot_mapping, + # TODO(woosuk): Support cascade attention. 
+ use_cascade=False, + common_prefix_len=0, + cu_prefix_query_lens=None, + prefix_kv_lens=None, + suffix_kv_lens=None, + ) + + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + hidden_states=target_hidden_states, + positions=target_positions, + ) + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + draft_token_ids, draft_probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + # [batch_size, 1] and [batch_size, 1, vocab_size] + return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1) + + # Generate the remaining draft tokens. + draft_token_ids_list = [draft_token_ids] + draft_probs_list = [draft_probs] + + positions = target_positions[last_token_indices] + hidden_states = sample_hidden_states + attn_metadata.num_actual_tokens = batch_size + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size] + for _ in range(self.num_speculative_tokens - 1): + # Update the inputs. + input_ids = draft_token_ids_list[-1] + positions += 1 + attn_metadata.max_seq_len += 1 + attn_metadata.seq_lens += 1 + # Compute the slot mapping. + block_numbers = positions // self.block_size + block_ids = block_table.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + attn_metadata.slot_mapping = (block_ids * self.block_size + + positions % self.block_size) + + # Run the model. + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + hidden_states=hidden_states, + positions=positions, + ) + logits = self.model.compute_logits(hidden_states, None) + draft_token_ids, probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) + draft_token_ids_list.append(draft_token_ids) + draft_probs_list.append(probs) + + # [batch_size, num_speculative_tokens] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + # [batch_size, num_speculative_tokens, vocab_size] + draft_probs = torch.stack(draft_probs_list, dim=1) + return draft_token_ids, draft_probs + + @staticmethod + def prepare_inputs( + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. 
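+        # `.item()` forces a GPU->CPU copy of the total token count, which is
+        # needed on the host to size the `token_indices` buffer below.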
+ num_tokens = cu_num_tokens[-1].item() + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) + + batch_size = num_rejected_tokens.shape[0] + BLOCK_SIZE = 1024 + prepare_input_kernel[(batch_size, )]( + token_indices, + cu_target_query_lens, + cu_num_tokens, + BLOCK_SIZE=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + def load_model(self, target_model: nn.Module) -> None: + self.model = DummyEagleModel() + self.model.get_input_embeddings = target_model.get_input_embeddings + self.model.compute_logits = target_model.compute_logits + + +# FIXME(woosuk): This is a dummy model for testing. +# Remove this once we have a real model. +class DummyEagleModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + input_embeddings = self.get_input_embeddings(input_ids) + return hidden_states + input_embeddings # Dummy return. + + +# FIXME(woosuk): The logic here is duplicated with the main sampling code. +# We should refactor this to reuse the same sampling implementation. +def compute_probs_and_sample_next_token( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> tuple[torch.Tensor, torch.Tensor]: + if sampling_metadata.all_greedy: + # For greedy requests, draft_probs is not used in rejection sampling. + # Therefore, we can just return the logits. + probs = logits + next_token_ids = logits.argmax(dim=-1) + return next_token_ids, probs + + is_greedy = sampling_metadata.temperature == -1 + temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits.div_(temperature.view(-1, 1)) + probs = logits.softmax(dim=-1, dtype=torch.float32) + + # NOTE(woosuk): Currently, we ignore most of the sampling parameters in + # generating the draft tokens. We only use the temperature. While this + # could degrade the acceptance rate, it does not affect the distribution + # of the generated tokens after rejection sampling. + + # TODO(woosuk): Consider seeds. 
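+    # Exponential-race (Gumbel-max style) sampling: dividing probs by i.i.d.
+    # Exponential(1) noise and taking the argmax selects index i with
+    # probability probs[i], without materializing a cumulative distribution.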
+ q = torch.empty_like(probs) + q.exponential_() + next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + if not sampling_metadata.all_random: + greedy_token_ids = probs.argmax(dim=-1) + next_token_ids = torch.where( + is_greedy, + greedy_token_ids, + next_token_ids, + ) + return next_token_ids, probs + + +@triton.jit +def prepare_input_kernel( + out_ptr, + cu_query_lens_ptr, + cu_num_tokens_ptr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + # [start_pos, end_pos) + start_pos = tl.load(cu_num_tokens_ptr + pid) + end_pos = tl.load(cu_num_tokens_ptr + pid + 1) + num_tokens = end_pos - start_pos + + index_start = tl.load(cu_query_lens_ptr + pid) + + num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) + for i in tl.range(num_blocks): + offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store( + out_ptr + start_pos + offset, + index_start + offset, + mask=offset < num_tokens, + ) diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py new file mode 100644 index 0000000..1cf650d --- /dev/null +++ b/vllm/v1/spec_decode/metadata.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import numpy as np +import torch + + +@dataclass +class SpecDecodeMetadata: + + # [num_tokens] + draft_token_ids: torch.Tensor + # [batch_size] + num_draft_tokens: list[int] + # [batch_size] + cu_num_draft_tokens: torch.Tensor + # [num_tokens] + target_logits_indices: torch.Tensor + # [batch_size] + bonus_logits_indices: torch.Tensor + # [num_tokens + batch_size] + logits_indices: torch.Tensor + + def __post_init__(self): + self.max_spec_len = max(self.num_draft_tokens) + + @classmethod + def make_dummy( + cls, + draft_token_ids: list[list[int]], + device: torch.device, + ) -> "SpecDecodeMetadata": + batch_size = len(draft_token_ids) + num_draft_tokens = [len(ids) for ids in draft_token_ids] + flattened_draft_token_ids = sum(draft_token_ids, []) + num_tokens = len(flattened_draft_token_ids) + + draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids, + dtype=torch.int32, + device=device) + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to( + device) + + target_logits_indices = torch.zeros(num_tokens, + dtype=torch.int32, + device=device) + bonus_logits_indices = torch.zeros(batch_size, + dtype=torch.int32, + device=device) + logits_indices = torch.zeros(num_tokens + batch_size, + dtype=torch.int32, + device=device) + return cls( + draft_token_ids=draft_token_ids_tensor, + num_draft_tokens=num_draft_tokens, + cu_num_draft_tokens=cu_num_draft_tokens_tensor, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py new file mode 100644 index 0000000..7bb3c20 --- /dev/null +++ b/vllm/v1/spec_decode/metrics.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import numpy as np + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class SpecDecodingStats: + num_draft_tokens: int = 0 + num_accepted_tokens: int = 0 + + def take(self): + copied = SpecDecodingStats(self.num_draft_tokens, + self.num_accepted_tokens) + self.reset() + return copied + + def reset(self): + self.num_draft_tokens = 0 + self.num_accepted_tokens = 0 + + def observe(self, num_draft_tokens: int, num_accepted_tokens: int): + self.num_draft_tokens += 
num_draft_tokens + self.num_accepted_tokens += num_accepted_tokens + + +class SpecDecodingMetrics: + + def __init__(self): + self.reset() + + def reset(self): + self.num_draft_tokens: list[int] = [] + self.num_accepted_tokens: list[int] = [] + + def observe(self, spec_decoding_stats: SpecDecodingStats): + self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) + self.num_accepted_tokens.append( + spec_decoding_stats.num_accepted_tokens) + + def log(self): + num_draft_tokens = np.sum(self.num_draft_tokens) + num_accepted_tokens = np.sum(self.num_accepted_tokens) + + draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * + 100 if num_draft_tokens > 0 else float("nan")) + + logger.info( + "SpecDecoding metrics: " + "Draft acceptance rate: %.1f%%, " + "Accepted: %d tokens, " + "Drafted: %d tokens", + draft_acceptance_rate, + num_accepted_tokens, + num_draft_tokens, + ) + self.reset() diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py new file mode 100644 index 0000000..7e548bb --- /dev/null +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import numpy as np +from numba import jit + +from vllm.config import VllmConfig + + +class NgramProposer: + + def __init__(self, vllm_config: VllmConfig): + # Minimum length of the n-gram to match. + self.min_n = vllm_config.speculative_config.prompt_lookup_min + # Maximum length of the n-gram to match. + self.max_n = vllm_config.speculative_config.prompt_lookup_max + # Number of tokens follow the match. If there are less than k + # tokens follow the match, we will return the maximum amount of + # tokens until the end. + self.k = vllm_config.speculative_config.num_speculative_tokens + # Trigger Numba JIT compilation for N-gram proposer. + # This usually takes less than 1 second. + self.propose(np.zeros(1024, dtype=np.int32)) + + def propose( + self, + context_token_ids: np.ndarray, + ) -> Optional[np.ndarray]: + """Proposes the next sequence of tokens based on n-gram pattern + matching in the context. The function finds matches of the last n + tokens in the previous context, and returns k tokens that followed + that match. + + Args: + context_token_ids: Numpy array of token IDs representing the + context sequence. + + Returns: + np.ndarray: The sequence of tokens that followed + the matched n-gram in the context. + None: If no matching n-gram pattern is found. + + Example: + If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and + k = 4: + - The last 3 (= max_n) tokens [4,2,3] cannot find a match. + - The last 2 tokens [2,3] will be matched against the previous + 4 tokens [1,2,3,4]. + - Finding a match of [2,3] would return the tokens that + followed that pattern. Here we will return [4,2,3] because + we only have three tokens after the match. + """ + # TODO(woosuk): Optimize this. + for n in range(self.max_n, self.min_n - 1, -1): + result = _find_subarray_kmp(context_token_ids, n, self.k) + if result is not None: + return result + return None + + def load_model(self, *args, **kwargs): + # No model to load. + pass + + +@jit(nopython=True) +def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray: + """ + Build the lps (longest proper prefix which is also suffix) + array for the pattern. 
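+
+    For example, the pattern [1, 2, 1, 2, 3] yields lps = [0, 0, 1, 2, 0].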
+ """ + lps = np.zeros(len(pattern), dtype=np.int32) + prev_lps = 0 # length of the previous longest prefix suffix + i = 1 + + while i < len(pattern): + if pattern[i] == pattern[prev_lps]: + prev_lps += 1 + lps[i] = prev_lps + i += 1 + else: + if prev_lps != 0: + prev_lps = lps[prev_lps - 1] + else: + lps[i] = 0 + i += 1 + return lps + + +@jit(nopython=True) +def _find_subarray_kmp( + context_token_ids: np.ndarray, + n: int, + k: int, +) -> Optional[np.ndarray]: + context_len = context_token_ids.shape[0] + assert n > 0 + + pattern = context_token_ids[-n:] + # Precompute lps array for Y + lps = _kmp_lps_array(pattern) + + i = 0 + j = 0 + # -n because the last n tokens are used as pattern + while i < context_len - n: + if context_token_ids[i] == pattern[j]: + i += 1 + j += 1 + + # If we have matched the entire Y + if j == n: + # Found pattern in context, gather the next K elements + return context_token_ids[i:i + k] + else: + # Mismatch + if j != 0: + # Use the lps array to avoid re-checking elements + j = lps[j - 1] + else: + i += 1 + + # Y not found + return None diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py new file mode 100644 index 0000000..ce81a40 --- /dev/null +++ b/vllm/v1/spec_decode/utils.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.v1.worker.gpu_input_batch import InputBatch + + +def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: + if req_id in input_batch.min_p_reqs: + # Spec decode doesn't support min_p sampling. + return False + elif (req_id in input_batch.frequency_penalties_reqs + or req_id in input_batch.presence_penalties_reqs + or req_id in input_batch.repetition_penalties_reqs): + # Spec decode doesn't support penalties. + return False + elif req_id in input_batch.num_logprobs: + # Spec decode doesn't support logprobs. + return False + + return True diff --git a/vllm/v1/stats/__init__.py b/vllm/v1/stats/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/stats/common.py b/vllm/v1/stats/common.py new file mode 100644 index 0000000..4681897 --- /dev/null +++ b/vllm/v1/stats/common.py @@ -0,0 +1,453 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from dataclasses import dataclass +from dataclasses import field as dataclass_field +from enum import IntEnum +from typing import ClassVar, Optional + +import msgspec +from msgspec import field as msgspec_field + +from vllm.sampling_params import SamplingParams + + +class RequestStatsUpdate( + msgspec.Struct, # type: ignore + array_like=True, + omit_defaults=True, + gc=False): + """ + An update to the request stats. + + This represents a stats update at a specific timestamp with metadata + associated with the update. + + NOTE: since there might be multiple processes generating updates at + different parts of the engine (e.g. input processor, scheduler, engine core, + etc.), we use the monotonic timestamp to record the update to compute any + intervals, and explicit wall-clock timestamp should be used for timestamps. + + WARNING: This assumes stats are generated in a single machine. If there are + potentially multiple machines, one should always generate the stats updates + on one single machine or use something else. + """ + + class Type(IntEnum): + """See `RequestStats` for the lifecycle of a request.""" + + # Request arrived at the engine frontend. + ARRIVED = 0 + # Input processed by the input processor. + INPUT_PROCESSED = 1 + # Queued on the engine core. + QUEUED = 2 + # Scheduled running prefill by the scheduler. 
+ # A request could be running a new prefill on the prompt tokens or + # a resumed prefill on the original prefill tokens + generated output + # tokens before preemption. + PREFILLING = 3 + # Preempted by the scheduler. + PREEMPTED = 4 + # Output token is generated by the engine core. + DECODING = 5 + # Token detokenized by the detokenizer. + # We will record the timestamp for each output token, as well as the + # finish reason. + DETOKENIZED = 6 + # Request finishes (or aborts). + FINISHED = 7 + + """ + Valid state updates: + ARRIVED + │ + ├──────► INPUT_PROCESSED ──────► QUEUED ──────► PREFILLING ◄────┐ + │ │ │ │ │ + │ │ │ ▼ │ + │ │ │ -──► DECODING │ + │ │ │ | │ │ + │ │ │ | ▼ │ + │ │ │ └─ DETOKENIZED │ + │ │ │ │ │ + │ │ │ ▼ │ + │ ▼ ▼ PREEMPTED ◄──────┘ + │ │ │ │ + └──────────────┴───────────────────┴──────────────┴ + │ + ▼ + FINISHED (All could go to FINISHED) + """ + _VALID_TRANSITIONS: ClassVar[dict[Type, set[Type]]] = { + Type.ARRIVED: { + Type.INPUT_PROCESSED, + Type.FINISHED, + }, + Type.INPUT_PROCESSED: { + Type.QUEUED, + Type.FINISHED, + }, + Type.QUEUED: { + Type.PREFILLING, + Type.FINISHED, + }, + Type.PREFILLING: { + Type.DECODING, + Type.PREEMPTED, + Type.FINISHED, + }, + Type.DECODING: { + Type.DETOKENIZED, + Type.FINISHED, + }, + Type.DETOKENIZED: { + Type.DECODING, + Type.PREEMPTED, + Type.FINISHED, + }, + Type.PREEMPTED: {Type.PREFILLING, Type.FINISHED}, + Type.FINISHED: set(), + } + + request_id: str + + type: Type + + # Timestamp when the update is recorded. This is used to record time + # intervals between events rather than wall clock time. + monotonic_ts_s: float = msgspec_field( + default_factory=lambda: time.monotonic()) + + ############################################################ + # Metadata associated with the update. + ############################################################ + # For input_processed. Metadata needed for stats logging. + num_prompt_tokens: Optional[int] = None + sampling_params: Optional[SamplingParams] = None + + # For running. + # Number of tokens computed when scheduled to run. + num_computed_tokens: Optional[int] = None + # Number of cached tokens when scheduled to run. + num_cached_tokens: Optional[int] = None + + # For decoded. + # The number of new output tokens generated. + num_new_tokens: Optional[int] = None + + # For both detokenized and decoded. + # Finished reason. + finish_reason: Optional[str] = None + + # Non-optional fields for each update type. 
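+    # `__post_init__` below validates these and raises a ValueError when a
+    # required field is missing for the given update type.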
+ _REQUIRED_FIELDS: ClassVar[dict[Type, list[str]]] = { + Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"], + Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"], + Type.DETOKENIZED: ["num_new_tokens"], + Type.FINISHED: ["finish_reason"], + } + + def __post_init__(self): + required_fields = self._REQUIRED_FIELDS.get(self.type, []) + for field in required_fields: + if getattr(self, field) is None: + raise ValueError( + f"Field {field} is required for update type {self.type}.") + + @staticmethod + def check_valid_update( + update: "RequestStatsUpdate", + last_update_type: Optional[Type], + last_updated_ts_s: Optional[float], + ): + if last_update_type is None: + assert update.type == RequestStatsUpdate.Type.ARRIVED + else: + valid_cur_update_types = RequestStatsUpdate._VALID_TRANSITIONS[ + last_update_type] + assert update.type in valid_cur_update_types, ( + f"Invalid update type: {update.type} for last_update_type: " + f"{last_update_type}.") + + if last_updated_ts_s is not None: + assert update.monotonic_ts_s >= last_updated_ts_s, ( + "Update timestamp must be monotonically increasing, but " + f"last_updated_ts_s={last_updated_ts_s} and " + f"update.monotonic_ts_s={update.monotonic_ts_s}.") + + +@dataclass +class RequestStats: + """Stats associated with a request (`Request`).""" + + ############################################################ + # Metadata + ############################################################ + request_id: str + sampling_params: Optional[SamplingParams] = None + num_prompt_tokens: Optional[int] = None + + ############################################################ + # Metrics and Stats + ############################################################ + # Timestamp when the request was last updated. + last_updated_ts_s: Optional[float] = None + + # Last update stats type. + last_update_type: Optional[RequestStatsUpdate.Type] = None + + # Timestamp when the request arrived at the llm engine. + arrival_ts_s: Optional[float] = None + + # Number of tokens cached. When part of the request prefix is cached, + # this will be set. + num_cached_tokens: int = 0 + + # Number of tokens computed. + num_computed_tokens: int = 0 + + # The timestamp when the request become waiting in the queue. + queued_ts_s: Optional[float] = None + + # When the input processor is completed. + input_processor_end_ts_s: Optional[float] = None + + # A sorted list of timestamps when the request was scheduled to prefill. + # This could be when: + # 1. the request is newly scheduled, so it's a new prefill. + # 2. the request was preempted and resumed. It is equivalent to running + # a prefill of the original prefill tokens + generated output tokens + # before preemption. + prefill_start_ts_s_lst: list[float] = dataclass_field(default_factory=list) + + # A list of timestamps when a token is decoded by the engine core. + decoding_ts_s_lst: list[float] = dataclass_field(default_factory=list) + + # A sorted list of timestamps for each output token. + output_token_ts_s_lst: list[float] = dataclass_field(default_factory=list) + + # First token's timestamp. + first_token_ts_s: Optional[float] = None + + # TODO(rickyx): we need model runner to surface these. + model_forward_duration_s: float = 0.0 + # Includes model forward, block/sync across workers, cpu-gpu sync time + # and sampling time. + model_execute_duration_s: float = 0.0 + + # A sorted list of timestamps when the request was preempted at the + # scheduler. 
+ # TODO(rickyx): right now, we don't actually have a good high-level + # metric to measure the impact of preemption other than observation of + # large P99 TPOT. Ideally we could quantify the impact of preemption by + # measuring the number of tokens re-computed due to preemption. + preempted_ts_s_lst: list[float] = dataclass_field(default_factory=list) + + # Timestamp when the request was finished at the engine core. + finished_ts_s: Optional[float] = None + + # Finish reason. + finish_reason: Optional[str] = None + + ############################################################ + # Derived properties. + ############################################################ + @property + def prefill_ts_s(self) -> Optional[float]: + """The timestamp when the request started prefilling. + Since a request could be preempted in decoding and later resumed + to prefill the decoded tokens, we use the first prefill start timestamp. + """ + return (self.prefill_start_ts_s_lst[0] + if self.prefill_start_ts_s_lst else None) + + @property + def e2e_latency_s(self) -> Optional[float]: + if self.finished_ts_s is None or self.arrival_ts_s is None: + return None + assert self.finished_ts_s >= self.arrival_ts_s + return self.finished_ts_s - self.arrival_ts_s + + @property + def queue_duration_s(self) -> Optional[float]: + """How long the request was waiting to run.""" + if self.queued_ts_s is None or self.prefill_ts_s is None: + # Either not queued or not running yet. + return None + assert self.queued_ts_s <= self.prefill_ts_s + return self.prefill_ts_s - self.queued_ts_s + + @property + def inference_latency_s(self) -> Optional[float]: + """How long the request was running inference + (prefill and decode).""" + if self.finished_ts_s is None or self.prefill_ts_s is None: + return None + assert self.finished_ts_s >= self.prefill_ts_s + return self.finished_ts_s - self.prefill_ts_s + + @property + def first_token_latency_s(self) -> Optional[float]: + if self.first_token_ts_s is None or self.arrival_ts_s is None: + return None + assert self.first_token_ts_s >= self.arrival_ts_s + return self.first_token_ts_s - self.arrival_ts_s + + @property + def prefill_latency_s(self) -> Optional[float]: + if self.first_token_ts_s is None or self.prefill_ts_s is None: + return None + assert self.first_token_ts_s >= self.prefill_ts_s + return self.first_token_ts_s - self.prefill_ts_s + + @property + def decode_latency_s(self) -> Optional[float]: + if self.e2e_latency_s is None or self.first_token_latency_s is None: + return None + assert self.e2e_latency_s >= self.first_token_latency_s + return self.e2e_latency_s - self.first_token_latency_s + + @property + def output_token_latency_s_lst(self) -> list[float]: + if len(self.output_token_ts_s_lst) == 0: + return [] + latency_s_lst = [] + for i in range(1, len(self.output_token_ts_s_lst)): + assert (self.output_token_ts_s_lst[i] + >= self.output_token_ts_s_lst[i - 1]) + latency_s = (self.output_token_ts_s_lst[i] - + self.output_token_ts_s_lst[i - 1]) + latency_s_lst.append(latency_s) + return latency_s_lst + + @property + def num_output_tokens(self) -> int: + return len(self.output_token_ts_s_lst) + + @property + def is_finished(self) -> bool: + return self.finished_ts_s is not None + + def update_from(self, update: "RequestStatsUpdate"): + RequestStatsUpdate.check_valid_update(update, self.last_update_type, + self.last_updated_ts_s) + ts = update.monotonic_ts_s + self.last_updated_ts_s = ts + self.last_update_type = update.type + if update.type == RequestStatsUpdate.Type.ARRIVED: + 
self.arrival_ts_s = ts + elif update.type == RequestStatsUpdate.Type.INPUT_PROCESSED: + self.input_processor_end_ts_s = ts + self.sampling_params = update.sampling_params + self.num_prompt_tokens = update.num_prompt_tokens + elif update.type == RequestStatsUpdate.Type.QUEUED: + self.queued_ts_s = ts + elif update.type == RequestStatsUpdate.Type.PREFILLING: + self.prefill_start_ts_s_lst.append(ts) + self.num_cached_tokens = update.num_cached_tokens or 0 + self.num_computed_tokens = update.num_computed_tokens or 0 + elif update.type == RequestStatsUpdate.Type.PREEMPTED: + self._reset_for_preemption(ts) + elif update.type == RequestStatsUpdate.Type.DECODING: + self.decoding_ts_s_lst.append(ts) + elif update.type == RequestStatsUpdate.Type.DETOKENIZED: + self._record_detokenized_output( + ts, + update.num_new_tokens or 0, + ) + elif update.type == RequestStatsUpdate.Type.FINISHED: + self.finished_ts_s = ts + self.finish_reason = update.finish_reason + else: + raise ValueError(f"Unknown update type: {update.type}") + + def _record_detokenized_output( + self, + ts_s: float, + num_new_tokens: int, + ): + # Update if first output token is generated. + if len(self.output_token_ts_s_lst) == 0: + self.first_token_ts_s = ts_s + assert ( + self.prefill_ts_s is not None + ), "Request must be running before generating output tokens." + + # Some X new tokens were generated at the ts. + self.output_token_ts_s_lst.extend([ts_s] * num_new_tokens) + + def _reset_for_preemption(self, ts_s: float): + self.preempted_ts_s_lst.append(ts_s) + # Reset the computed tokens since it might restart the prefill. + self.num_computed_tokens = 0 + # Cached token count might also change when resumed. + self.num_cached_tokens = 0 + # These stats don't change since they happen before request running. + # - arrival_ts_s + # - input_processor_end_ts_s + # - sampling_params + # - num_prompt_tokens + # - first_token_ts_s + # + # These stats are accumulated over preemptions: + # - output_token_ts_s_lst + # - prefill_start_ts_s_lst (after preemption, it will prefill the + # original prefill tokens and any output tokens generated before + # preemption.) + + +@dataclass +class KVCacheStats: + # KV Cache Usage in % + gpu_cache_usage_sys: float = 0.0 + gpu_prefix_cache_hit_rate: float = 0.0 + + +@dataclass +class SchedulerStats: + """Stats associated with the scheduler.""" + + # Number of requests currently running. + num_running_reqs: int = 0 + # Number of requests currently waiting. + num_waiting_reqs: int = 0 + + kv_cache_stats: KVCacheStats = dataclass_field( + default_factory=KVCacheStats) + + +@dataclass +class EngineCoreProcessStats: + """Stats associated with the engine core process.""" + + # Number of requests currently in the input queue. None if the engine core + # is not running in multiprocess mode. + input_queue_size: Optional[int] = None + # Number of outputs currently in the output queue. None if the engine core + # is not running in multiprocess mode. + output_queue_size: Optional[int] = None + + +class EngineCoreStatsSnapshot( + msgspec.Struct, # type: ignore + array_like=True, + omit_defaults=True, + gc=False): + """ + A snapshot of the EngineCore's current stats over a period of time. + """ + + # Snapshot of the scheduler stats. + scheduler_stats: SchedulerStats = msgspec_field( + default_factory=SchedulerStats) + + # Per request stats updates. + requests_stats_updates: list[RequestStatsUpdate] = msgspec_field( + default_factory=list) + + # Engine core's queue stats. 
+ engine_core_process_stats: EngineCoreProcessStats = msgspec_field( + default_factory=EngineCoreProcessStats) + + # TODO(rickyx): Add other components' stats, + # e.g. model runner/worker and etc. diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py new file mode 100644 index 0000000..218af43 --- /dev/null +++ b/vllm/v1/structured_output/__init__.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import multiprocessing +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Optional + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.structured_output.backend_guidance import GuidanceBackend +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar) + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + import torch + + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class StructuredOutputManager: + """Engine-level manager for structured output requests.""" + + def __init__(self, vllm_config: VllmConfig): + self.backend: Optional[StructuredOutputBackend] = None + self.vllm_config = vllm_config + self._grammar_bitmask: Optional[torch.Tensor] = None + + # The default max_workers if not specified is the number of CPUs * 5, + # which is way too high since these tasks are CPU-bound, not I/O bound. + # We also know we would never dominate CPU usage with just grammar + # compilation, so we set it to half the number of CPUs. + max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) + self.executor = ThreadPoolExecutor(max_workers=max_workers) + + def grammar_init(self, request: Request) -> None: + if request.structured_output_request is None: + return + + # Initialize the backend the first time it is needed. + # + # NOTE: We only support a single backend. We do NOT support different + # backends on a per-request basis in V1 (for now, anyway...). + if self.backend is None: + backend_name = request.sampling_params.guided_decoding.backend_name + if backend_name == "xgrammar": + from vllm.v1.structured_output.backend_xgrammar import ( + XgrammarBackend) + + self.backend = XgrammarBackend(self.vllm_config) + elif backend_name == "guidance": + self.backend = GuidanceBackend(self.vllm_config) + else: + raise ValueError( + f"Unsupported structured output backend: {backend_name}") + + grammar = self.executor.submit(self._async_create_grammar, request) + request.structured_output_request.grammar = grammar # type: ignore[assignment] + + def _async_create_grammar( + self, + request: Request, + ) -> StructuredOutputGrammar: + key = request.structured_output_request.structured_output_key # type: ignore[union-attr] + + # Note that the request was validated in the engine core client, + # so at this point we know it is a supported type of request. + # + # TODO: we still need to handle xgrammar compilation failures, + # though it should be unlikely as we test that up front as well. + request_type, grammar_spec = key + + assert self.backend is not None + return self.backend.compile_grammar(request_type, grammar_spec) + + def grammar_bitmask( + self, + requests: dict[str, Request], + structured_output_request_ids: dict[str, int], + batch_len: int, + ) -> Optional[npt.NDArray[np.int32]]: + # Prepare the structured output bitmask for this batch. 
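+        # The bitmask is allocated once below (sized for max_num_seqs) and
+        # reused across steps. Each row packs one bit per vocabulary token
+        # into int32 words; a set bit marks a token the grammar still allows.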
+ if not structured_output_request_ids: + return None + + if self._grammar_bitmask is None: + assert self.backend is not None + self._grammar_bitmask = self.backend.allocate_token_bitmask( + self.vllm_config.scheduler_config.max_num_seqs) + + # Fill the bitmask using the index of each request equal to its + # position in the batch. Resize the bitmask down to the size of + # the batch. + bitmask_tensor = self._grammar_bitmask + for req_id, batch_index in structured_output_request_ids.items(): + request = requests[req_id].structured_output_request + assert request is not None and request.grammar is not None + if not request.grammar.is_terminated(): + request.grammar.fill_bitmask(bitmask_tensor, batch_index) + if batch_len < self._grammar_bitmask.shape[0]: + bitmask_tensor = self._grammar_bitmask[:batch_len] + + # After finishing with the xgrammar operations, we convert to + # np.ndarray, because that is much more efficient for serialization + # and deserialization when sending this to the GPU workers. + return bitmask_tensor.numpy() diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py new file mode 100644 index 0000000..a7ba710 --- /dev/null +++ b/vllm/v1/structured_output/backend_guidance.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) +from vllm.v1.structured_output.request import get_structured_output_key + +if TYPE_CHECKING: + import llguidance + import llguidance.hf as llguidance_hf + import llguidance.torch as llguidance_torch +else: + llguidance = LazyLoader("llguidance", globals(), "llguidance") + llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf") + llguidance_torch = LazyLoader("llguidance.torch", globals(), + "llguidance.torch") + +logger = init_logger(__name__) + + +class GuidanceBackend(StructuredOutputBackend): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + tokenizer_group = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) # type: ignore[arg-type] + tokenizer_group.ping() + self.vllm_config = vllm_config + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.disable_any_whitespace = ( + "disable-any-whitespace" + in vllm_config.decoding_config.guided_decoding_backend) + + tokenizer = tokenizer_group.get_lora_tokenizer(None) + self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None) + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + self.serialized_grammar = serialize_guidance_grammar( + request_type, grammar_spec, self.disable_any_whitespace) + + ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, + self.serialized_grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + + r = GuidanceGrammar( + ll_matcher=ll_matcher, + ll_tokenizer=self.ll_tokenizer, + vocab_size=self.vocab_size, + ) + + r.check_error() + return r + + 
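
For orientation, the following stand-alone sketch shows one way the packed `int32` bitmask returned by `grammar_bitmask()` above could be expanded and applied to a batch of logits. The function name, shapes, and per-word little-endian bit order are assumptions for illustration; the engine's workers apply the mask with the backends' own routines rather than this code.

```python
import numpy as np
import torch


def apply_packed_bitmask(logits: torch.Tensor,
                         packed_bitmask: np.ndarray) -> torch.Tensor:
    """Mask disallowed tokens given a packed int32 token bitmask.

    logits:         [batch, vocab_size] float tensor.
    packed_bitmask: [batch, ceil(vocab_size / 32)] int32 array,
                    one bit per token (assumed little-endian per word).
    """
    batch, vocab_size = logits.shape
    words = torch.from_numpy(packed_bitmask).to(logits.device)   # [batch, n_words]
    # Unpack every int32 word into its 32 bits.
    shifts = torch.arange(32, device=logits.device, dtype=torch.int32)
    allowed = ((words.unsqueeze(-1) >> shifts) & 1).bool()       # [batch, n_words, 32]
    allowed = allowed.reshape(batch, -1)[:, :vocab_size]         # [batch, vocab_size]
    # Tokens whose bit is cleared are forbidden by the grammar at this step.
    return logits.masked_fill(~allowed, float("-inf"))
```

Setting the logits of cleared tokens to -inf forces their sampling probability to zero, which is how the bitmask constrains generation at each decoding step.
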
def allocate_token_bitmask(self, max_num_seqs: int): + return llguidance_torch.allocate_token_bitmask( + max_num_seqs, self.ll_tokenizer.vocab_size) + + +@dataclass +class GuidanceGrammar(StructuredOutputGrammar): + ll_matcher: llguidance.LLMatcher + ll_tokenizer: llguidance.LLTokenizer + vocab_size: int + printed_error: bool = False + terminated: bool = False + + def check_error(self): + if not self.printed_error: + err = self.ll_matcher.get_error() + if err: + self.printed_error = True + logger.warning("LLMatcher error: %s", err) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """Accepts a list of tokens and advances the parser. + + Returns True if the parser was advanced successfully. + Returns False if the parser failed to advance. + """ + + if self.ll_tokenizer.eos_token in tokens: + self.terminated = True + + if self.ll_matcher.is_stopped(): + return True + + # TODO - Add jump decoding support in the future: + # self.ll_matcher.compute_ff_bytes() - this should always work + # self.ll_matcher.compute_ff_tokens() - this only works for + # "canonical" tokenizers + # For conversion between the two, see + # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md + + r = self.ll_matcher.consume_tokens(tokens) + + self.check_error() + + return r + + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + # this will automatically return [EOS] mask if the matcher is stopped + # or otherwise in an error state + llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx) + self.check_error() + + def is_terminated(self) -> bool: + return self.terminated + + def reset(self): + # This method may be not needed anymore? TODO + self.ll_matcher.reset() + + +def serialize_guidance_grammar(request_type: StructuredOutputOptions, + grammar_spec: str, + disable_any_whitespace: bool = False) -> str: + if request_type == StructuredOutputOptions.JSON: + return llguidance.LLMatcher.grammar_from_json_schema( + grammar_spec, + defaults={ + "whitespace_flexible": not disable_any_whitespace, + }) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + return llguidance.LLMatcher.grammar_from_json_schema( + '{"type": "object"}', + defaults={ + "whitespace_flexible": not disable_any_whitespace, + }) + else: + if request_type == StructuredOutputOptions.REGEX: + tp = "regex" + elif request_type == StructuredOutputOptions.GRAMMAR: + tp = "grammar" + elif request_type == StructuredOutputOptions.CHOICE: + tp = "choice" + else: + logger.error("Validation should have already occurred. " + "Please file an issue.") + raise ValueError("grammar is not of valid supported types. 
" + f"({request_type!s})") + return llguidance.grammar_from(tp, grammar_spec) + + +def validate_guidance_grammar( + sampling_params: SamplingParams, + tokenizer: Optional[llguidance.LLTokenizer] = None) -> None: + tp, grm = get_structured_output_key(sampling_params) + guidance_grm = serialize_guidance_grammar(tp, grm) + err = llguidance.LLMatcher.validate_grammar(guidance_grm, + tokenizer=tokenizer) + if err: + raise ValueError(f"Grammar error: {err}") diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py new file mode 100644 index 0000000..6dc2a92 --- /dev/null +++ b/vllm/v1/structured_output/backend_types.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +from abc import ABC, abstractmethod + +import torch + + +class StructuredOutputOptions(enum.Enum): + JSON = enum.auto() + JSON_OBJECT = enum.auto() + REGEX = enum.auto() + GRAMMAR = enum.auto() + CHOICE = enum.auto() + + +StructuredOutputKey = tuple[StructuredOutputOptions, str] + + +class StructuredOutputGrammar(ABC): + """Request-level backend for structured output requests.""" + + @abstractmethod + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """ + Determines whether the provided tokens are accepted for the + given request. + + Args: + request_id (str): The unique identifier for the request. + tokens (list[int]): A list of token IDs to evaluate. + + Returns: + bool: True if the tokens are accepted, False otherwise. + """ + + @abstractmethod + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + """ + Fills the bitmask for a specific batch index. + + Args: + bitmask (torch.Tensor): The bitmask to fill + batch_index (int): The index in the bitmask to fill + """ + + @abstractmethod + def is_terminated(self) -> bool: + """ + Checks whether the structured output process has terminated. + + Returns: + bool: True if the process is terminated, False otherwise. + """ + + @abstractmethod + def reset(self): + """ + Resets the state of the structured output grammar. + """ + + +class StructuredOutputBackend(ABC): + """Engine-level backend for structured output requests.""" + + @abstractmethod + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + """ + Compiles a grammar specification into a structured output grammar. + + Args: + request_type (StructuredOutputOptions): The type of structured + output request. + grammar_spec (str): The grammar specification to compile. + + Returns: + StructuredOutputGrammar: The compiled structured output grammar. + """ + + @abstractmethod + def allocate_token_bitmask(self, max_num_seqs: int): + """ + Allocates a token bitmask for the specified maximum number of sequences. + + Args: + max_num_seqs (int): The maximum number of sequences for which + to allocate the bitmask. 
+ """ diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py new file mode 100644 index 0000000..783a334 --- /dev/null +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) + +if TYPE_CHECKING: + import xgrammar as xgr +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +logger = init_logger(__name__) + + +class XgrammarBackend(StructuredOutputBackend): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + self.disable_any_whitespace = ( + "disable-any-whitespace" + in vllm_config.decoding_config.guided_decoding_backend) + tokenizer_group = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) # type: ignore[arg-type] + tokenizer_group.ping() + + tokenizer = tokenizer_group.get_lora_tokenizer(None) + self.vocab_size = vllm_config.model_config.get_vocab_size() + if isinstance(tokenizer, MistralTokenizer): + # NOTE: ideally, xgrammar should handle this accordingly. + # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 + try: + if tokenizer.is_tekken: + encoded_vocab = tokenizer._vocab + else: + encoded_vocab = [ + token for token, _ in sorted( + tokenizer.get_vocab().items(), + key=lambda x: x[1], + ) + ] + stop_token_ids = None + if hasattr( + tokenizer, + "eos_token_id", + ) and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. 
The tokenizer should have a " + "get_vocab method.") from e + tokenizer_info = xgr.TokenizerInfo( # type: ignore + encoded_vocab=encoded_vocab, + # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + vocab_type=xgr.VocabType.RAW + if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + vocab_size=self.vocab_size, + stop_token_ids=stop_token_ids, + add_prefix_space=True, + ) + else: + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, + vocab_size=self.vocab_size, + ) + self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + if request_type == StructuredOutputOptions.JSON: + ctx = self.compiler.compile_json_schema( + grammar_spec, any_whitespace=not self.disable_any_whitespace) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + ctx = self.compiler.compile_json_schema( + '{"type": "object"}', + any_whitespace=not self.disable_any_whitespace) + elif request_type == StructuredOutputOptions.GRAMMAR: + ctx = self.compiler.compile_grammar(grammar_spec) + elif request_type == StructuredOutputOptions.REGEX: + ctx = self.compiler.compile_regex(grammar_spec) + else: + logger.error( + "Validation should have already occurred. Please file an issue." + ) + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})") + + return XgrammarGrammar( + matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.vocab_size, + ctx=ctx, + ) + + def allocate_token_bitmask(self, max_num_seqs: int): + return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size) + + +@dataclass +class XgrammarGrammar(StructuredOutputGrammar): + # NOTE: This would be a generic-enough class for + # supporting different backends, in the future. + # For now, just xgrammar. + # + # TODO: support max_rollback_tokens + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string + # for jump-forward decoding + + vocab_size: int + matcher: xgr.GrammarMatcher = field(hash=False) + ctx: xgr.CompiledGrammar = field(hash=False) + num_processed_tokens: int = field(default_factory=lambda: 0, + repr=False, + hash=False, + init=False) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """Accepts a list of tokens and advances the FSM. + + Returns True if the FSM was advanced successfully. + Returns False if the FSM failed to advance. + """ + for token in tokens: + if not self.matcher.accept_token(token): + logger.error( + "Failed to advance FSM for request %s " + "for tokens %s. 
Please file an issue.", request_id, token) + return False + self.num_processed_tokens += 1 + return True + + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + self.matcher.fill_next_token_bitmask(bitmask, idx) + + def is_terminated(self) -> bool: + return self.matcher.is_terminated() + + def reset(self): + self.num_processed_tokens = 0 + self.matcher.reset() diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py new file mode 100644 index 0000000..9e54b8b --- /dev/null +++ b/vllm/v1/structured_output/request.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import dataclasses +import functools +import json +from concurrent.futures import Future +from concurrent.futures._base import TimeoutError +from typing import Optional, Union, cast + +from vllm.sampling_params import SamplingParams +from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, + StructuredOutputKey, + StructuredOutputOptions) + + +@dataclasses.dataclass +class StructuredOutputRequest: + + sampling_params: SamplingParams + _grammar: Optional[Union[Future[StructuredOutputGrammar], + StructuredOutputGrammar]] = None + + def _check_grammar_completion(self) -> bool: + # NOTE: We have to lazy import to gate circular imports + from vllm.v1.request import RequestStatus + + if isinstance(self._grammar, Future): + try: + # We will check whether the future is ready within 100 us + self._grammar = self._grammar.result(timeout=0.0001) + self.status = RequestStatus.WAITING + except TimeoutError: + return False + return True + + @property + def is_grammar_ready(self) -> bool: + return self._check_grammar_completion() + + @property + def grammar(self) -> Optional[StructuredOutputGrammar]: + completed = self._check_grammar_completion() + return cast(Optional[StructuredOutputGrammar], + self._grammar) if completed else None + + @grammar.setter + def grammar( + self, grammar: Union[StructuredOutputGrammar, + Future[StructuredOutputGrammar]] + ) -> None: + self._grammar = grammar + + @functools.cached_property + def structured_output_key(self) -> StructuredOutputKey: + return get_structured_output_key(self.sampling_params) + + +def get_structured_output_key( + sampling_params: SamplingParams) -> StructuredOutputKey: + params = sampling_params.guided_decoding + assert params is not None, "params can't be None." 
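+    # Normalize the constraint into a (request type, spec string) pair; dict
+    # schemas and choice lists are serialized with json.dumps so the spec is
+    # always a plain string.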
+ if params.json is not None: + if not isinstance(params.json, str): + json_str = json.dumps(params.json) + else: + json_str = params.json + return (StructuredOutputOptions.JSON, json_str) + elif params.json_object: + return (StructuredOutputOptions.JSON_OBJECT, "") + elif params.regex is not None: + return (StructuredOutputOptions.REGEX, params.regex) + elif params.choice is not None: + if not isinstance(params.choice, str): + json_str = json.dumps(params.choice) + else: + json_str = params.choice + return (StructuredOutputOptions.CHOICE, json_str) + elif params.grammar is not None: + return (StructuredOutputOptions.GRAMMAR, params.grammar) + else: + raise ValueError("No valid structured output parameter found") diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py new file mode 100644 index 0000000..a771256 --- /dev/null +++ b/vllm/v1/structured_output/utils.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import re +from typing import TYPE_CHECKING, Any + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader + +if TYPE_CHECKING: + import xgrammar as xgr +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + + +def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict[str, Any]) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj + for key in ("minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf")): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any( + key in obj + for key in ("uniqueItems", "contains", "minContains", + "maxContains", "minItems", "maxItems")): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and any( + key in obj for key in ("minLength", "maxLength", "format")): + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any( + key in obj for key in ("minProperties", "maxProperties", + "propertyNames", "patternProperties")): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. + + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for EBNF rule definition + if '::=' in line: + return False + + return True + + +def convert_lark_to_ebnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to EBNF format. 
+ + EBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in EBNF format + + Examples: + >>> print(convert_lark_to_ebnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) + + +def choice_as_grammar(choice: list[str]) -> str: + + def escape_ebnf_string(s: str) -> str: + """Escape special characters in a EBNF string.""" + # Escape double quotes and backslashes + return re.sub(r'(["\\])', r'\\\1', s) + + escaped_choices = (escape_ebnf_string(c) for c in choice) + grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + return grammar + + +def validate_structured_output_request_xgrammar( + sampling_params: SamplingParams) -> None: + """Validate that the request is supported by structured output. + + Raises ValueError if the request is not supported. 
+ """ + if sampling_params.guided_decoding is None: + return + + gd_params = sampling_params.guided_decoding + + if gd_params.regex: + try: + xgr.Grammar.from_regex(gd_params.regex) + except Exception as err: + raise ValueError("Failed to transform regex into a grammar: " + f"{err}") from err + + if gd_params.choice: + choice_grammar = choice_as_grammar(gd_params.choice) + try: + xgr.Grammar.from_ebnf(choice_grammar) + except Exception as err: + raise ValueError("Failed to transform choices into a grammar: " + "{err}") from err + gd_params.choice = None + gd_params.grammar = choice_grammar + return + + if gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError("The provided JSON schema contains features not " + "supported by xgrammar.") + return + + if gd_params.grammar: + if grammar_is_likely_lark(gd_params.grammar): + # xgrammar supports EBNF grammars only + try: + gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to EBNF. ") from e + + # Test parsing EBNF grammar, possibly already converted from Lark + try: + # parse the grammar, but we aren't compiling it. + xgr.Grammar.from_ebnf(gd_params.grammar) + except Exception as e: + raise ValueError("Invalid grammar specification.") from e diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py new file mode 100644 index 0000000..9c5bf8e --- /dev/null +++ b/vllm/v1/utils.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing +import os +import weakref +from collections import defaultdict +from collections.abc import Sequence +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.models.utils import extract_layer_index +from vllm.utils import get_mp_context, kill_process_tree + +if TYPE_CHECKING: + from vllm.attention.layer import Attention + +logger = init_logger(__name__) + +T = TypeVar("T") + + +class ConstantList(Generic[T], Sequence): + + def __init__(self, x: list[T]) -> None: + self._x = x + + def append(self, item): + raise Exception("Cannot append to a constant list") + + def extend(self, item): + raise Exception("Cannot extend a constant list") + + def insert(self, item): + raise Exception("Cannot insert into a constant list") + + def pop(self, item): + raise Exception("Cannot pop from a constant list") + + def remove(self, item): + raise Exception("Cannot remove from a constant list") + + def clear(self): + raise Exception("Cannot clear a constant list") + + def index(self, + item: T, + start: int = 0, + stop: Optional[int] = None) -> int: + return self._x.index(item, start, + stop if stop is not None else len(self._x)) + + @overload + def __getitem__(self, item: int) -> T: + ... + + @overload + def __getitem__(self, s: slice, /) -> list[T]: + ... + + def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]: + return self._x[item] + + @overload + def __setitem__(self, item: int, value: T): + ... + + @overload + def __setitem__(self, s: slice, value: T, /): + ... 
+ + def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]): + raise Exception("Cannot set item in a constant list") + + def __delitem__(self, item): + raise Exception("Cannot delete item from a constant list") + + def __iter__(self): + return iter(self._x) + + def __contains__(self, item): + return item in self._x + + def __len__(self): + return len(self._x) + + def __repr__(self): + return f"ConstantList({self._x})" + + +class BackgroundProcHandle: + """ + Utility class to handle creation, readiness, and shutdown + of background processes used by the AsyncLLM and LLMEngine. + """ + + def __init__( + self, + input_path: str, + output_path: str, + process_name: str, + target_fn: Callable, + process_kwargs: dict[Any, Any], + ): + context = get_mp_context() + self.reader, writer = context.Pipe(duplex=False) + + assert ("ready_pipe" not in process_kwargs + and "input_path" not in process_kwargs + and "output_path" not in process_kwargs) + process_kwargs["ready_pipe"] = writer + process_kwargs["input_path"] = input_path + process_kwargs["output_path"] = output_path + + # Run busy loop in background process. + self.proc = context.Process(target=target_fn, + kwargs=process_kwargs, + name=process_name) + self._finalizer = weakref.finalize(self, shutdown, self.proc, + input_path, output_path) + self.proc.start() + + def wait_for_startup(self): + # Wait for startup. + if self.reader.recv()["status"] != "READY": + raise RuntimeError(f"{self.proc.name} initialization failed. " + "See root cause above.") + + def shutdown(self): + self._finalizer() + + +# Note(rob): shutdown function cannot be a bound method, +# else the gc cannot collect the object. +def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): + # Shutdown the process. + if proc.is_alive(): + proc.terminate() + proc.join(5) + + if proc.is_alive(): + kill_process_tree(proc.pid) + + # Remove zmq ipc socket files. + ipc_sockets = [output_path, input_path] + for ipc_socket in ipc_sockets: + socket_file = ipc_socket.replace("ipc://", "") + if os and os.path.exists(socket_file): + os.remove(socket_file) + + +def bind_kv_cache( + kv_caches: dict[str, torch.Tensor], + forward_context: dict[str, "Attention"], + runner_kv_caches: list[torch.Tensor], +) -> None: + """ + Bind the allocated KV cache to both ModelRunner and forward context so + that the KV cache can be used in the forward pass. + + This function: + 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with + kv_caches. + 2) Associates each attention layer in the `forward_context` with its + corresponding KV cache in kv_caches. + + Args: + kv_caches: The allocated kv_caches with layer names as keys. + forward_context: The global forward context containing all Attention + layers with layer names as keys. + runner_kv_caches: The kv_cache declared by ModelRunner. + """ + # Bind kv_caches to ModelRunner + assert len(runner_kv_caches) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. 
+ raise NotImplementedError + layer_name = layer_names[0] + runner_kv_caches.append(kv_caches[layer_name]) + + # Bind kv_caches to forward context + for layer_name, kv_cache in kv_caches.items(): + # NOTE: Use list because of v0 PP virtual engine. + forward_context[layer_name].kv_cache = [kv_cache] + +def bind_kv_cache_scale( + kv_caches_scale: dict[str, torch.Tensor], + forward_context: dict[str, "Attention"], + runner_kv_caches_scale: list[torch.Tensor], +) -> None: + """ + Bind the allocated KV cache to both ModelRunner and forward context so + that the KV cache can be used in the forward pass. + + This function: + 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with + kv_caches. + 2) Associates each attention layer in the `forward_context` with its + corresponding KV cache in kv_caches. + + Args: + kv_caches: The allocated kv_caches with layer names as keys. + forward_context: The global forward context containing all Attention + layers with layer names as keys. + runner_kv_caches: The kv_cache declared by ModelRunner. + """ + # Bind kv_caches to ModelRunner + assert len(runner_kv_caches_scale) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. + index2name = defaultdict(list) + for layer_name in kv_caches_scale: + index2name[extract_layer_index(layer_name)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. + raise NotImplementedError + layer_name = layer_names[0] + runner_kv_caches_scale.append(kv_caches_scale[layer_name]) + + # Bind kv_caches to forward context + for layer_name, kv_cache_scale in kv_caches_scale.items(): + # NOTE: Use list because of v0 PP virtual engine. + forward_context[layer_name].kv_cache_scale = [kv_cache_scale] + + +def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, + length: int) -> torch.Tensor: + """ + Copy the first length elements of a tensor into another tensor in a + non-blocking manner. + + Used to copy pinned CPU tensor data to pre-allocated GPU tensors. + + Returns the sliced target tensor. 
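+
+    Note that the copy is asynchronous with respect to the host, so the
+    source tensor must not be modified until the copy has completed.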
+ """ + return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) diff --git a/vllm/v1/worker/__init__.py b/vllm/v1/worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py new file mode 100644 index 0000000..7d4082b --- /dev/null +++ b/vllm/v1/worker/block_table.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BlockTable: + + def __init__( + self, + max_num_reqs: int, + max_num_blocks_per_req: int, + pin_memory: bool, + device: torch.device, + ): + self.max_num_reqs = max_num_reqs + self.max_num_blocks_per_req = max_num_blocks_per_req + self.pin_memory = pin_memory + self.device = device + + self.block_table = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32, + ) + self.block_table_cpu = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_np = self.block_table_cpu.numpy() + self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + + def append_row( + self, + block_ids: list[int], + row_idx: int, + ) -> None: + if not block_ids: + return + num_blocks = len(block_ids) + start = self.num_blocks_per_row[row_idx] + self.num_blocks_per_row[row_idx] += num_blocks + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + + def add_row(self, block_ids: list[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) + + def move_row(self, src: int, tgt: int) -> None: + num_blocks = self.num_blocks_per_row[src] + self.block_table_np[tgt, :num_blocks] = self.block_table_np[ + src, :num_blocks] + self.num_blocks_per_row[tgt] = num_blocks + + def swap_row(self, src: int, tgt: int) -> None: + num_blocks_src = self.num_blocks_per_row[src] + num_blocks_tgt = self.num_blocks_per_row[tgt] + self.num_blocks_per_row[src] = num_blocks_tgt + self.num_blocks_per_row[tgt] = num_blocks_src + + self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] + + def commit(self, num_reqs: int) -> None: + self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], + non_blocking=True) + + def clear(self) -> None: + self.block_table.fill_(0) + self.block_table_cpu.fill_(0) + + def get_device_tensor(self) -> torch.Tensor: + """Ruturns the device tensor of the block table.""" + return self.block_table + + def get_cpu_tensor(self) -> torch.Tensor: + """Returns the CPU tensor of the block table.""" + return self.block_table_cpu + + def get_numpy_array(self) -> np.ndarray: + """Returns the numpy array of the block table.""" + return self.block_table_np diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py new file mode 100644 index 0000000..a64cb97 --- /dev/null +++ b/vllm/v1/worker/gpu_input_batch.py @@ -0,0 +1,678 @@ +# SPDX-License-Identifier: Apache-2.0 +# Datastructures defining an input batch + +from dataclasses import dataclass +from typing import Optional, cast + +import numpy as np +import torch + +from vllm.lora.request import LoRARequest +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import swap_dict_values +from vllm.v1.outputs import LogprobsTensors +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import copy_slice +from 
vllm.v1.worker.block_table import BlockTable + +_SAMPLING_EPS = 1e-5 + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: list[int] + prompt: Optional[str] + mm_inputs: list[MultiModalKwargs] + mm_positions: list[PlaceholderRange] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: list[int] + num_computed_tokens: int + output_token_ids: list[int] + + mrope_positions: Optional[torch.Tensor] = None + mrope_position_delta: Optional[int] = None + + lora_request: Optional[LoRARequest] = None + + def __post_init__(self): + self.num_prompt_tokens = len(self.prompt_token_ids) + + @property + def num_tokens(self) -> int: + return self.num_prompt_tokens + len(self.output_token_ids) + + def get_token_id(self, idx: int) -> int: + if idx < self.num_prompt_tokens: + return self.prompt_token_ids[idx] + else: + return self.output_token_ids[idx - self.num_prompt_tokens] + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + self.vocab_size = vocab_size + + self._req_ids: list[Optional[str]] = [] + self.req_id_to_index: dict[str, int] = {} + + # TODO(woosuk): This buffer could be too large if max_model_len is big. + # Find a way to reduce the CPU memory usage. + # This buffer is not directly transferred to the GPU, so it does not + # need to be pinned. + self.token_ids_cpu_tensor = torch.zeros( + (max_num_reqs, max_model_len), + device="cpu", + dtype=torch.int32, + pin_memory=False, + ) + self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu_tensor = torch.zeros( + (max_num_reqs, ), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_computed_tokens_cpu = \ + self.num_computed_tokens_cpu_tensor.numpy() + + # Block table. + self.block_table = BlockTable( + max_num_reqs=max_num_reqs, + max_num_blocks_per_req=max_num_blocks_per_req, + pin_memory=pin_memory, + device=device, + ) + + # Sampling-related. 
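+        # Each parameter below is kept both as a CPU tensor (mutated as
+        # requests are added, removed, or condensed) and as a device tensor
+        # that is refreshed from the CPU copy when the batch changes.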
+ self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: set[str] = set() + self.random_reqs: set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: set[str] = set() + + self.min_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.min_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.min_p_cpu = self.min_p_cpu_tensor.numpy() + self.min_p_reqs: set[str] = set() + + # Frequency penalty related data structures + self.frequency_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_reqs: set[str] = set() + + # Presence penalty related data structures + self.presence_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + ) + self.presence_penalties_reqs: set[str] = set() + + # Repetition penalty related data structures + self.repetition_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu = \ + self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_reqs: set[str] = set() + + # req_index -> (min_tokens, stop_token_ids) + self.min_tokens: dict[int, tuple[int, set[int]]] = {} + + # lora related + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + self.lora_id_to_request_ids: dict[int, set[str]] = {} + self.lora_id_to_lora_request: dict[int, LoRARequest] = {} + + # req_index -> generator + # NOTE(woosuk): The indices of the requests that do not have their own + # generator should not be included in the dictionary. + self.generators: dict[int, torch.Generator] = {} + + self.num_logprobs: dict[str, int] = {} + # NOTE(rob): num_prompt_logprobs only includes reqs + # that are currently in the prefill phase. + self.num_prompt_logprobs: dict[str, int] = {} + + # To accumulate prompt logprobs tensor chunks across prefill steps. + self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} + + self.logit_bias: list[Optional[dict[int, + float]]] = [None] * max_num_reqs + self.has_allowed_token_ids: set[str] = set() + # NOTE(lufang): In the mask tensor, if the corresponding token allowed, + # the value is False. 
Since we use masked_fill_ to set -inf. + self.allowed_token_ids_mask: Optional[torch.Tensor] = None + self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None + + # req_index -> bad_words_token_ids + self.bad_words_token_ids: dict[int, list[list[int]]] = {} + + self.req_output_token_ids: list[Optional[list[int]]] = [] + + # This is updated each time the batch constituents change. + self.sampling_metadata = self._make_sampling_metadata() + + @property + def req_ids(self) -> list[str]: + # None elements should only be present transiently + # while performing state updates to the batch. + return cast(list[str], self._req_ids) + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + if req_index == len(self._req_ids): + self._req_ids.append(req_id) + self.req_output_token_ids.append(request.output_token_ids) + else: + self._req_ids[req_index] = req_id + self.req_output_token_ids[req_index] = request.output_token_ids + + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.num_prompt_tokens[req_index] = num_prompt_tokens + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + # Number of token ids in token_ids_cpu. + # NOTE(woosuk): This may include spec decode tokens. + self.num_tokens[req_index] = request.num_tokens + # Number of tokens without spec decode tokens. + self.num_tokens_no_spec[req_index] = request.num_tokens + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + self.block_table.add_row(request.block_ids, req_index) + + sampling_params = request.sampling_params + if sampling_params.sampling_type == SamplingType.GREEDY: + # Avoid later division by zero. + self.temperature_cpu[req_index] = -1.0 + self.greedy_reqs.add(req_id) + else: + self.temperature_cpu[req_index] = sampling_params.temperature + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + top_k = sampling_params.top_k + if 0 < top_k < self.vocab_size: + self.top_k_reqs.add(req_id) + else: + top_k = self.vocab_size + self.top_k_cpu[req_index] = top_k + self.min_p_cpu[req_index] = sampling_params.min_p + self.frequency_penalties_cpu[ + req_index] = sampling_params.frequency_penalty + if sampling_params.min_p > _SAMPLING_EPS: + self.min_p_reqs.add(req_id) + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[ + req_index] = sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[ + req_index] = sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) + if sampling_params.min_tokens: + self.min_tokens[req_index] = (sampling_params.min_tokens, + sampling_params.all_stop_token_ids) + + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. 
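+        # (A request carries its own generator only when it specifies a seed.)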
+ if request.generator is not None: + self.generators[req_index] = request.generator + + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs + if sampling_params.logit_bias is not None: + self.logit_bias[req_index] = sampling_params.logit_bias + + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu_tensor is None: + # Lazy allocation for this tensor, which can be large. + # False means we don't fill with -inf. + self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device=self.device) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device="cpu") + self.allowed_token_ids_mask_cpu_tensor[req_index] = True + # False means we don't fill with -inf. + self.allowed_token_ids_mask_cpu_tensor[req_index][ + sampling_params.allowed_token_ids] = False + + if sampling_params.bad_words_token_ids: + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids + + # Add request lora ID + if request.lora_request: + lora_id = request.lora_request.lora_int_id + if lora_id not in self.lora_id_to_request_ids: + self.lora_id_to_request_ids[lora_id] = set() + + self.request_lora_mapping[req_index] = lora_id + self.lora_id_to_request_ids[lora_id].add(request.req_id) + self.lora_id_to_lora_request[lora_id] = request.lora_request + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + + def remove_request(self, req_id: str) -> Optional[int]: + """This method must always be followed by a call to condense().""" + + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self._req_ids[req_index] = None + self.req_output_token_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.min_p_reqs.discard(req_id) + self.min_tokens.pop(req_index, None) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) + self.in_progress_prompt_logprobs_cpu.pop(req_id, None) + + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + self.lora_id_to_request_ids[lora_id].discard(req_id) + if len(self.lora_id_to_request_ids[lora_id]) == 0: + self.lora_id_to_request_ids.pop(lora_id) + self.lora_id_to_lora_request.pop(lora_id) + self.request_lora_mapping[req_index] = 0 + + self.logit_bias[req_index] = None + self.has_allowed_token_ids.discard(req_id) + if self.allowed_token_ids_mask_cpu_tensor is not None: + # False means we don't fill with -inf. 
+ self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) + self.bad_words_token_ids.pop(req_index, None) + return req_index + + def swap_states(self, i1: int, i2: int) -> None: + old_id_i1 = self._req_ids[i1] + old_id_i2 = self._req_ids[i2] + self._req_ids[i1], self._req_ids[i2] =\ + self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ + self.req_output_token_ids[i2], self.req_output_token_ids[i1] + assert old_id_i1 is not None and old_id_i2 is not None + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ + self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] + self.num_tokens[i1], self.num_tokens[i2] =\ + self.num_tokens[i2], self.num_tokens[i1] + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ + self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ + self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ + self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] + self.temperature_cpu[i1], self.temperature_cpu[i2] =\ + self.temperature_cpu[i2], self.temperature_cpu[i1] + self.top_p_cpu[i1], self.top_p_cpu[i2] =\ + self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] =\ + self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ + self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ + self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ + self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.min_p_cpu[i1], self.min_p_cpu[i2] =\ + self.min_p_cpu[i2], self.min_p_cpu[i1] + + # NOTE: the following is unsafe + # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ + # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] + # instead, we need to temporiarily copy the data for one of the indices + # TODO(lucas): optimize this by only copying valid indices + tmp = self.token_ids_cpu[i1, ...].copy() + self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] + self.token_ids_cpu[i2, ...] = tmp + + swap_dict_values(self.generators, i1, i2) + swap_dict_values(self.min_tokens, i1, i2) + swap_dict_values(self.bad_words_token_ids, i1, i2) + + self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ + self.request_lora_mapping[i2], self.request_lora_mapping[i1] + self.logit_bias[i1], self.logit_bias[i2] =\ + self.logit_bias[i2], self.logit_bias[i1] + + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[i1], \ + self.allowed_token_ids_mask_cpu_tensor[i2] =\ + self.allowed_token_ids_mask_cpu_tensor[i2], \ + self.allowed_token_ids_mask_cpu_tensor[i1] + self.block_table.swap_row(i1, i2) + + def condense(self, empty_req_indices: list[int]) -> None: + num_reqs = self.num_reqs + if num_reqs == 0: + # The batched states are empty. + self._req_ids.clear() + self.req_output_token_ids.clear() + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. 
+ empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self._req_ids[last_req_index] + output_token_ids = self.req_output_token_ids[last_req_index] + assert req_id is not None + self._req_ids[empty_index] = req_id + self._req_ids[last_req_index] = None + self.req_output_token_ids[empty_index] = output_token_ids + self.req_output_token_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + num_tokens = self.num_tokens[last_req_index] + self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ + last_req_index, :num_tokens] + self.num_tokens[empty_index] = num_tokens + self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ + last_req_index] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table.move_row(last_req_index, empty_index) + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.frequency_penalties_cpu[ + empty_index] = self.frequency_penalties_cpu[last_req_index] + self.presence_penalties_cpu[ + empty_index] = self.presence_penalties_cpu[last_req_index] + self.repetition_penalties_cpu[ + empty_index] = self.repetition_penalties_cpu[last_req_index] + self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + min_token = self.min_tokens.pop(last_req_index, None) + if min_token is not None: + self.min_tokens[empty_index] = min_token + + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + + self.logit_bias[empty_index] = self.logit_bias[last_req_index] + + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[ + empty_index] = self.allowed_token_ids_mask_cpu_tensor[ + last_req_index] + + bad_words_token_ids = self.bad_words_token_ids.pop( + last_req_index, None) + if bad_words_token_ids is not None: + self.bad_words_token_ids[empty_index] = bad_words_token_ids + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + # Trim lists to the batch size. + del self._req_ids[self.num_reqs:] + del self.req_output_token_ids[self.num_reqs:] + + def refresh_sampling_metadata(self): + self.sampling_metadata = self._make_sampling_metadata() + + def _make_sampling_metadata(self) -> SamplingMetadata: + num_reqs = self.num_reqs + if not self.all_greedy: + temperature = copy_slice(self.temperature_cpu_tensor, + self.temperature, num_reqs) + else: + temperature = None + if not self.no_top_p: + copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) + if not self.no_top_k: + copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) + if not self.no_min_p: + copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs) + + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. 
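+            # copy_slice(src_cpu, dst_gpu, n) copies the first n entries of
+            # the pinned CPU tensor into its GPU counterpart, so only the
+            # rows belonging to the current batch are transferred.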
+ copy_slice(self.frequency_penalties_cpu_tensor, + self.frequency_penalties, num_reqs) + copy_slice(self.presence_penalties_cpu_tensor, + self.presence_penalties, num_reqs) + copy_slice(self.repetition_penalties_cpu_tensor, + self.repetition_penalties, num_reqs) + + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + prompt_token_ids = self._make_prompt_token_ids_tensor() + else: + prompt_token_ids = None + + allowed_token_ids_mask: Optional[torch.Tensor] = None + if not self.no_allowed_token_ids: + assert self.allowed_token_ids_mask is not None + copy_slice(self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, num_reqs) + allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] + + return SamplingMetadata( + temperature=temperature, + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=None if self.no_top_p else self.top_p[:num_reqs], + top_k=None if self.no_top_k else self.top_k[:num_reqs], + min_p=None if self.no_min_p else self.min_p[:num_reqs], + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=prompt_token_ids, + frequency_penalties=self.frequency_penalties[:num_reqs], + presence_penalties=self.presence_penalties[:num_reqs], + repetition_penalties=self.repetition_penalties[:num_reqs], + output_token_ids=cast(list[list[int]], self.req_output_token_ids), + min_tokens=self.min_tokens, + no_penalties=self.no_penalties, + logit_bias=self.logit_bias[:num_reqs], + allowed_token_ids_mask=allowed_token_ids_mask, + bad_words_token_ids=self.bad_words_token_ids, + ) + + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: + max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + prompt_token_ids_cpu_tensor = torch.empty( + (self.num_reqs, max_prompt_len), + device="cpu", + dtype=torch.int64, + pin_memory=self.pin_memory, + ) + prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() + prompt_token_ids[:] = self.token_ids_cpu[:self. + num_reqs, :max_prompt_len] + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + for i in range(self.num_reqs): + prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, + non_blocking=True) + + def make_lora_inputs( + self, num_scheduled_tokens: np.ndarray + ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: + """ + Given the num_scheduled_tokens for each request in the batch, return + datastructures used to activate the current LoRAs. + Returns: + 1. prompt_lora_mapping: A tuple of size self.num_reqs where, + prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. + 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) + where, token_lora_mapping[i] is the LoRA id to use for ith token. + 3. lora_requests: Set of relevant LoRA requests. 
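+
+        Example: with num_scheduled_tokens = [3, 2] and per-request LoRA ids
+        (1, 0), this returns prompt_lora_mapping = (1, 0) and
+        token_lora_mapping = (1, 1, 1, 0, 0).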
+ """ + + req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + prompt_lora_mapping = tuple(req_lora_mapping) + token_lora_mapping = tuple( + req_lora_mapping.repeat(num_scheduled_tokens)) + active_lora_requests: set[LoRARequest] = set( + self.lora_id_to_lora_request.values()) + + return prompt_lora_mapping, token_lora_mapping, active_lora_requests + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def no_min_p(self) -> bool: + return len(self.min_p_reqs) == 0 + + @property + def no_penalties(self) -> bool: + return (len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0) + + @property + def max_num_logprobs(self) -> Optional[int]: + return max(self.num_logprobs.values()) if self.num_logprobs else None + + @property + def no_prompt_logprob(self) -> bool: + return not self.num_prompt_logprobs + + @property + def no_allowed_token_ids(self) -> bool: + return len(self.has_allowed_token_ids) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py new file mode 100644 index 0000000..358bfb1 --- /dev/null +++ b/vllm/v1/worker/gpu_model_runner.py @@ -0,0 +1,1717 @@ +# SPDX-License-Identifier: Apache-2.0 + +import gc +import time +import weakref +from typing import TYPE_CHECKING, Optional, Union + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.layer import Attention +from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed.parallel_state import get_pp_group, graph_capture +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.sampling_params import SamplingType +from vllm.sequence import IntermediateTensors +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + GiB_bytes, LayerBlockType, LazyLoader, cdiv, + check_use_alibi, is_pin_memory_available) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, + SlidingWindowSpec) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, + ModelRunnerOutput) +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.utils import is_spec_decode_supported +from vllm.v1.utils import bind_kv_cache, bind_kv_cache_scale +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from 
vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +import vllm.envs as envs +from .utils import sanity_check_mm_encoder_outputs + +if TYPE_CHECKING: + import xgrammar as xgr + + from vllm.v1.core.sched.output import SchedulerOutput +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +logger = init_logger(__name__) + + +class GPUModelRunner(LoRAModelRunnerMixin): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + from vllm.model_executor.models.utils import set_cpu_offload_max_bytes + set_cpu_offload_max_bytes( + int(self.cache_config.cpu_offload_gb * 1024**3)) + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + # NOTE(woosuk): sliding_window is None for models with interleaved + # attention. Use interleaved_sliding_window instead. + self.sliding_window = model_config.get_sliding_window() + self.interleaved_sliding_window = getattr( + model_config.hf_text_config, "interleaved_sliding_window", None) + self.window_size = (self.sliding_window + or self.interleaved_sliding_window) + + self.is_multimodal_model = model_config.is_multimodal_model + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + + # Model-related. 
+ self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + self.attention_chunk_size = model_config.attention_chunk_size + + self.attn_backend = get_attn_backend( + self.head_size, + self.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + if self.attn_backend is None: + error_msg = ( + f"Error with get_att_backend: {self.head_size=}, " + f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{self.model_config.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 GPUModelRunner.") + + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self)) + self.cascade_attn_enabled = not self.model_config.disable_cascade_attn + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope + + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + mm_registry=self.mm_registry, + ) + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: list[torch.Tensor] = [] + self.kv_caches_scale: list[torch.Tensor] = [] + + # req_id -> (input_id -> encoder_output) + self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + + # Set up speculative decoding. + self.use_spec_decode = False + if self.speculative_config: + self.use_spec_decode = True + if get_pp_group().is_last_rank: + if self.speculative_config.method == "ngram": + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "eagle": + self.drafter = EagleProposer(self.vllm_config, + self.device) # type: ignore + else: + raise ValueError("Unknown speculative decoding method: " + f"{self.speculative_config.method}") + self.rejection_sampler = RejectionSampler() + + # Request states. + self.requests: dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), + ) + + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE + and not self.model_config.enforce_eager) + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. + # The convention is different. + # self.cudagraph_batch_sizes sorts in ascending order. + # The batch sizes in the config are in descending order. + self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # Cache the device properties. + self.device_properties = torch.cuda.get_device_properties(self.device) + self.num_sms = self.device_properties.multi_processor_count + + # Persistent buffers for CUDA graphs. 
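+        # These buffers are allocated once and reused every step; passing
+        # slices of the same storage keeps tensor addresses stable, which is
+        # what CUDA graph replay requires.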
+ self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: Optional[IntermediateTensors] = None + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + # NOTE: `mrope_positions` is implemented with one additional dummy + # position on purpose to make it non-contiguous so that it can work + # with torch compile. + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. + # See page 5 of https://arxiv.org/abs/2409.12191 + self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), + dtype=torch.int64, + device=self.device) + self.mrope_positions_cpu = torch.zeros( + (3, self.max_num_tokens + 1), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + # Only relevant for models using ALiBi (e.g, MPT) + self.use_alibi = check_use_alibi(model_config) + + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) + + # OPTIMIZATION: Cache the tensors rather than creating them every step. + self.arange_np = np.arange(max(self.max_num_reqs + 1, + self.max_model_len, + self.max_num_tokens), + dtype=np.int32) + # NOTE(woosuk): These tensors are "stateless", i.e., they are literally + # a faster version of creating a new tensor every time. Thus, we should + # not make any assumptions about the values in these tensors. + self.input_ids_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.input_ids_np = self.input_ids_cpu.numpy() + self.positions_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.positions_np = self.positions_cpu.numpy() + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + self.seq_lens_cpu = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_lens_np = self.seq_lens_cpu.numpy() + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. 
In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: list[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + req_index = self.input_batch.remove_request(req_id) + assert req_index is not None + removed_req_indices.append(req_index) + + req_ids_to_add: list[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend( + mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) + if mm_input.get("second_per_grid_ts") is not None: + second_per_grid_ts.extend( + mm_input["second_per_grid_ts"]) + + hf_config = self.model_config.hf_config + + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. 
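+        # For each request that is already running (or resuming from
+        # preemption), fold the tokens sampled in the previous step and the
+        # newly allocated block IDs into both the cached request state and
+        # the persistent batch.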
+ for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + # Update the cached states. + num_computed_tokens = req_data.num_computed_tokens + req_state.num_computed_tokens = num_computed_tokens + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec decode tokens. + num_new_tokens = (num_computed_tokens + + len(req_data.new_token_ids) - + req_state.num_tokens) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(req_data.new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + req_data.new_token_ids[-num_new_tokens:]) + # Update the block IDs. + if not req_data.resumed_from_preemption: + # Append the new blocks to the existing block IDs. + req_state.block_ids.extend(req_data.new_block_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = req_data.new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(req_data.new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = req_data.new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, ()) + if spec_token_ids: + start_index = end_token_index + end_token_index += len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec decode tokens. + self.input_batch.num_tokens[req_index] = end_token_index + + # Check if the batch has changed. If not, we can skip copying the + # sampling metadata from CPU to GPU. + batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0 + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. 
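+        # condense() back-fills the holes left by removed requests so that
+        # rows [0, num_reqs) of the batch tensors stay densely packed.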
+ if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + if batch_changed: + self.input_batch.refresh_sampling_metadata() + + def _prepare_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[FlashAttentionMetadata, torch.Tensor, + Optional[SpecDecodeMetadata]]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # Some attention backends (namely MLA) may want to separate requests + # based on if the attention computation will be compute-bound or + # memory-bound. This gives them a hook to do that. + modified_batch = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + if modified_batch: + self.input_batch.refresh_sampling_metadata() + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) + max_num_scheduled_tokens = 0 + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens[i] = num_tokens + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_scheduled_tokens) + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, + num_scheduled_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets + + # Get positions. + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. 
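+        # Instead, look up each token's block number in its request's block
+        # table row and compute:
+        #   slot = block_number * block_size + position % block_size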
+ block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + + # Copy the tensors to the GPU. + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + # Common case (1D positions) + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) + + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + ) + + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + ) + + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + logits_indices = attn_metadata.query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + return attn_metadata, logits_indices, spec_decode_metadata + + def _compute_cascade_attn_prefix_len( + self, + num_scheduled_tokens: np.ndarray, + num_common_prefix_blocks: int, + ) -> int: + """Compute the length of the common prefix for cascade attention. + + NOTE(woosuk): The common prefix length returned by this function + represents the length used specifically for cascade attention, not the + actual number of tokens shared between requests. When cascade attention + is disabled (use_cascade=False), this function returns 0 even if + requests share common tokens. 
Additionally, the common prefix length is + truncated to a multiple of the block size and may be further truncated + due to implementation details explained below. + + Args: + num_scheduled_tokens: Number of tokens scheduled per request. + num_common_prefix_blocks: Number of shared KV cache blocks. + + Returns: + int: Length of common prefix in tokens. + """ + common_prefix_len = num_common_prefix_blocks * self.block_size + if common_prefix_len == 0: + # Common case. + return 0 + + # NOTE(woosuk): Cascade attention uses two attention kernels: one + # for the common prefix and the other for the rest. For the first + # kernel, we concatenate all the query tokens (possibly from + # different requests) and treat them as if they are from the same + # request. Then, we use bi-directional attention to process the + # common prefix in the KV cache. Importantly, this means that the + # first kernel does not do any masking. + + # Consider the following example: + # Request 1's input query: [D, E, X] + # Request 1's kv cache: [A, B, C, D, E, X] + # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # Request 2's input query: [E, Y] + # Request 2's kv cache: [A, B, C, D, E, Y] + # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # If we use [A, B, C, D, E] as the common prefix, then the + # first kernel will compute the bi-directional attention between + # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. + # However, this is wrong because D in Request 1 should not attend to + # E in the common prefix (i.e., we need masking). + # To avoid this, [A, B, C, D] should be the common prefix. + # That is, the common prefix should be capped by the minimum + # num_computed_tokens among the requests, and plus one to include + # the first token of the query. + + # In practice, we use [A, B, C] as the common prefix, instead of + # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # num_computed_tokens, without plus one). + # This is because of an implementation detail: We want to always + # use two kernels for cascade attention. Let's imagine: + # Request 3's input query: [D] + # Request 3's kv cache: [A, B, C, D] + # Request 3's num_computed_tokens: 3 (i.e., [A, B, C]) + # If we use [A, B, C, D] as the common prefix for Request 1-3, + # then Request 3 will be processed only by the first kernel, + # and the second kernel will get an empty input. While this is not + # a fundamental problem, our current implementation does not support + # this case. + num_reqs = len(num_scheduled_tokens) + common_prefix_len = min( + common_prefix_len, + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # common_prefix_len should be a multiple of the block size. 
+ common_prefix_len = (common_prefix_len // self.block_size * + self.block_size) + use_cascade = self.attn_backend.use_cascade_attention( + common_prefix_len=common_prefix_len, + query_lens=num_scheduled_tokens, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + use_alibi=self.use_alibi, + use_sliding_window=self.window_size is not None, + num_sms=self.num_sms, + ) + return common_prefix_len if use_cascade else 0 + + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + mrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, + num_prompt_tokens - num_computed_tokens) + completion_part_len = max( + 0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] + + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + MRotaryEmbedding.get_next_input_positions_tensor( + req.mrope_position_delta, + context_len=num_computed_tokens + + prompt_part_len, + seq_len=num_computed_tokens + + prompt_part_len + + completion_part_len, + ) + + mrope_pos_ptr += completion_part_len + + def _calc_spec_decode_metadata( + self, + num_draft_tokens: np.ndarray, + cu_num_scheduled_tokens: np.ndarray, + ) -> SpecDecodeMetadata: + # Inputs: + # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] + # num_draft_tokens: [ 3, 0, 2, 0, 1] + # Outputs: + # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] + # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, + # 206, 207, 208] + # target_logits_indices: [ 0, 1, 2, 5, 6, 9] + # bonus_logits_indices: [ 3, 4, 7, 8, 10] + + # Compute the logits indices. + # [4, 1, 3, 1, 2] + num_sampled_tokens = num_draft_tokens + 1 + # Step 1. [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) + total_num_sampled_tokens = cu_num_sampled_tokens[-1] + # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, + num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets + # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_indices = np.repeat( + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + logits_indices += arange + + # Compute the bonus logits indices. + bonus_logits_indices = cu_num_sampled_tokens - 1 + + # Compute the draft logits indices. 
+ # [3, 3, 5, 5, 6] + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + total_num_draft_tokens = cu_num_draft_tokens[-1] + # [0, 0, 0, 3, 3, 5] + cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, + num_draft_tokens) + # [0, 1, 2, 0, 1, 0] + arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets + # [0, 0, 0, 5, 5, 9] + target_logits_indices = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + # [0, 1, 2, 5, 6, 9] + target_logits_indices += arange + + # TODO: Optimize the CPU -> GPU copy. + cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( + self.device, non_blocking=True) + logits_indices = torch.from_numpy(logits_indices).to(self.device, + non_blocking=True) + target_logits_indices = torch.from_numpy(target_logits_indices).to( + self.device, non_blocking=True) + bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( + self.device, non_blocking=True) + + # Compute the draft token ids. + # draft_token_indices: [ 1, 2, 3, 105, 106, 208] + draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = draft_token_ids[target_logits_indices + 1] + + metadata = SpecDecodeMetadata( + draft_token_ids=draft_token_ids, + num_draft_tokens=num_draft_tokens.tolist(), + cu_num_draft_tokens=cu_num_draft_tokens, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) + return metadata + + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs: list[MultiModalKwargs] = [] + req_input_ids: list[tuple[str, int]] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + for input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[input_id]) + req_input_ids.append((req_id, input_id)) + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, + # we process it separately to preserve item order. + # FIXME(ywang96): This is a hacky way to deal with multiple modalities + # in the same batch while still being able to benefit from batching + # multimodal inputs. The proper solution should be reordering the + # encoder outputs. + grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) + + encoder_outputs = [] + for grouped_mm_inputs in grouped_mm_inputs_list: + batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, each of shape + # (feature_size, hidden_size) in case the feature size is dynamic + # depending on the input multimodal items. + curr_group_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) + + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=len(grouped_mm_inputs), + ) + + for output in curr_group_outputs: + encoder_outputs.append(output) + + # Cache the encoder outputs. 
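+        # Outputs are keyed by (req_id, input_id) so that
+        # _gather_encoder_outputs() can later slice out exactly the encoder
+        # tokens that overlap the token range scheduled in a given step.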
+ for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + self.encoder_cache[req_id][input_id] = output + + def _gather_encoder_outputs( + self, + scheduler_output: "SchedulerOutput", + ) -> list[torch.Tensor]: + encoder_outputs: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] + encoder_outputs.append(encoder_output[start_idx:end_idx]) + return encoder_outputs + + def get_model(self) -> nn.Module: + return self.model + + def apply_grammar_bitmask( + self, + scheduler_output: "SchedulerOutput", + logits: torch.Tensor, + ): + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = scheduler_output.grammar_bitmask + if grammar_bitmask is None: + return + + # We receive the structured output bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the gpu runner is + # ordering the requests in the batch. We need to sort the bitmask to + # match the order of the requests used here. 
+ struct_out_req_batch_indices: dict[str, int] = {} + indices_match = True + for req_id in self.input_batch.req_ids: + mask_index = scheduler_output.structured_output_request_ids.get( + req_id) + if mask_index is None: + # not a structured output request + continue + batch_index = self.input_batch.req_id_to_index[req_id] + if batch_index != mask_index: + indices_match = False + struct_out_req_batch_indices[req_id] = batch_index + + if not indices_match: + # Sort the bitmask to match the order of the requests + sorted_bitmask = np.zeros_like(grammar_bitmask) + for req_id, batch_index in struct_out_req_batch_indices.items(): + orig_index = scheduler_output.structured_output_request_ids[ + req_id] + sorted_bitmask[batch_index] = grammar_bitmask[orig_index] + grammar_bitmask = sorted_bitmask + + grammar_bitmask = torch.from_numpy(grammar_bitmask) + + # TODO: compatibility with spec decode + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask.to(self.device, non_blocking=True), + indices=list(struct_out_req_batch_indices.values()), + ) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, torch.Tensor]: + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + # Return empty ModelRunnerOuptut if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_encoder(scheduler_output) + encoder_outputs = self._gather_encoder_outputs(scheduler_output) + else: + encoder_outputs = [] + + # Prepare the decoder inputs. + attn_metadata, logits_indices, spec_decode_metadata = ( + self._prepare_inputs(scheduler_output)) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + # Eager mode. + num_input_tokens = num_scheduled_tokens + attn_metadata.num_input_tokens = num_input_tokens + + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[:num_scheduled_tokens] + if encoder_outputs: + inputs_embeds = self.model.get_input_embeddings( + input_ids, encoder_outputs) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. 
+ input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + assert intermediate_tensors is not None + assert self.intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + self.intermediate_tensors[k][:num_input_tokens].copy_( + v[:num_input_tokens], non_blocking=True) + intermediate_tensors = IntermediateTensors({ + k: v[:num_input_tokens] + for k, v in self.intermediate_tensors.items() + }) + + # Run the decoder. + # Use persistent buffers for CUDA graphs. + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + return hidden_states + + hidden_states = hidden_states[:num_scheduled_tokens] + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.model.sample( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + # Ignore the sampled token for partial prefills. + # Rewind the generator state as if the token was not sampled. + # This relies on cuda-specific torch-internal impl details + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + # Record the index of the request that should not be sampled, + # so that we could clear the sampled tokens before returning. + discard_sampled_tokens_req_indices.append(i) + + # NOTE: GPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. 
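+        # (The .tolists()/.tolist() calls below are what trigger the
+        # device-to-host copy and hence the synchronization.)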
+ logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states, + scheduler_output, + ) + + # Get the valid generated tokens. + sampled_token_ids = sampler_output.sampled_token_ids + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + if not self.use_spec_decode: + # Speculative decoding is not enabled. + spec_token_ids = None + elif self.speculative_config.method == "ngram": + assert isinstance(self.drafter, NgramProposer) + spec_token_ids = self.generate_draft_token_ids( + valid_sampled_token_ids, sampling_metadata) + elif self.speculative_config.method == "eagle": + assert isinstance(self.drafter, EagleProposer) + # TODO(woosuk): Refactor the loop. + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.input_ids[:num_scheduled_tokens] + target_positions = positions + target_hidden_states = hidden_states + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: + # TODO(woosuk): Refactor this. + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + cu_num_tokens, token_indices = self.drafter.prepare_inputs( + attn_metadata.query_start_loc, + num_rejected_tokens, + ) + target_token_ids = self.input_ids[token_indices] + target_positions = positions[token_indices] + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = attn_metadata.slot_mapping[token_indices] + + draft_token_ids, draft_probs = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=attn_metadata.block_table, + sampling_metadata=sampling_metadata, + ) + spec_token_ids = draft_token_ids.tolist() + # TODO(woosuk): Cache draft_probs and use it for rejection sampling + # in the next step. 
+ del draft_probs + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + ) + + def generate_draft_token_ids( + self, + sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + ) -> list[list[int]]: + # TODO(woosuk): Optimize. + draft_token_ids: list[list[int]] = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + draft_token_ids.append([]) + continue + + # Skip requests that require top-p, top-k, etc. + req_id = self.input_batch.req_ids[i] + if not is_spec_decode_supported(req_id, self.input_batch): + draft_token_ids.append([]) + continue + + # Add sampled_token_ids to token_ids_cpu. + start_idx = self.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + num_sampled_ids + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids + drafter_output = self.drafter.propose( + self.input_batch.token_ids_cpu[i, :end_idx]) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: # noqa: SIM117 + time_before_load = time.perf_counter() + self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) + if hasattr(self, "drafter"): + logger.info("Loading drafter model...") + self.drafter.load_model(self.model) + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + logger.info("Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load) + + def _get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + scheduler_output: "SchedulerOutput", + ) -> dict[str, Optional[LogprobsTensors]]: + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.requests[req_id] + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. 
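+ # Prompt logprobs exist only for prompt positions 1..N-1: the logprob of prompt token i comes from the logits computed at position i-1, hence the +1 offset applied below.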
+ start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc_np[req_idx].item() + prompt_hidden_states = hidden_states[offset:offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states, None) + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.model.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.model.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids) + + # Transfer GPU->CPU async. + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_( + token_ids, non_blocking=True) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, + non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_( + ranks, non_blocking=True) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + if prompt_logprobs_dict: + torch.cuda.synchronize() + + return prompt_logprobs_dict + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + ) -> torch.Tensor: + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens): + model = self.model + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + intermediate_tensors = IntermediateTensors({ + k: v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + with set_forward_context(None, + self.vllm_config, + num_tokens=num_tokens): + hidden_states = model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + return hidden_states[logit_indices] + + @torch.inference_mode() + def _dummy_sampler_run( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + logits = self.model.compute_logits(hidden_states, None) + num_reqs = logits.size(0) + + dummy_tensors = lambda v: torch.full( + (num_reqs, ), v, device=self.device) + + dummy_metadata = SamplingMetadata( + temperature=dummy_tensors(0.5), + all_greedy=False, + all_random=False, + top_p=dummy_tensors(0.9), + top_k=dummy_tensors(logits.size(1) - 1), + min_p=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=dummy_tensors(0.1), + presence_penalties=dummy_tensors(0.1), + repetition_penalties=dummy_tensors(0.1), + output_token_ids=[[] for _ in range(num_reqs)], + min_tokens={}, + logit_bias=[None for _ in range(num_reqs)], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + ) + try: + sampler_output = self.model.sample( + logits=logits, sampling_metadata=dummy_metadata) + except RuntimeError as e: + if 'out of memory' in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up sampler with " + f"{num_reqs} dummy requests. Please try lowering " + "`max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine.") from e + else: + raise e + if self.use_spec_decode: + draft_token_ids = [[0] for _ in range(num_reqs)] + dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids, self.device) + + num_tokens = sum(len(ids) for ids in draft_token_ids) + # draft_probs = torch.randn( + # num_tokens, logits.shape[-1], device=self.device, + # dtype=logits.dtype) + draft_probs = None + target_logits = torch.randn(num_tokens, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) + # NOTE(woosuk): Here, we should use int32 because the sampler uses + # int32 for bonus_token_ids. 
If the dtype mismatches, re-compilation + # will occur at runtime. + bonus_token_ids = torch.zeros(num_reqs, + device=self.device, + dtype=torch.int32) + self.rejection_sampler( + dummy_spec_decode_metadata, + draft_probs, + target_logits, + bonus_token_ids, + dummy_metadata, + ) + return sampler_output + + def profile_run(self) -> None: + # Profile with multimodal encoder & encoder cache. + # TODO: handle encoder-decoder models once we support them. + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + max_tokens_by_modality_dict = self.mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(self.model_config) + dummy_data_modality, max_tokens_per_mm_item = max( + max_tokens_by_modality_dict.items(), key=lambda item: item[1]) + + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + max_num_mm_items_encoder_budget = cdiv(encoder_budget, + max_tokens_per_mm_item) + + # Check how many items of this modality can be supported by + # the decoder budget. + max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt( + self.model_config)[dummy_data_modality] + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. + max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + + logger.info( + "Encoder cache will be initialized with a budget of %s tokens," + " and profiled with %s %s items of the maximum feature size.", + encoder_budget, max_num_mm_items, dummy_data_modality) + + # Create dummy batch of multimodal inputs. + dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + mm_counts={ + dummy_data_modality: 1 + }, + ).multi_modal_data + + batched_dummy_mm_inputs = MultiModalKwargs.batch( + [dummy_mm_kwargs] * max_num_mm_items) + batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( + batched_dummy_mm_inputs, device=self.device) + + # Run multimodal encoder. + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_num_mm_items, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + + hidden_states = self._dummy_run(self.max_num_tokens) + if get_pp_group().is_last_rank: + sampler_output = self._dummy_sampler_run(hidden_states) + else: + sampler_output = None + torch.cuda.synchronize() + del hidden_states, sampler_output + self.encoder_cache.clear() + gc.collect() + + def capture_model(self) -> None: + if not self.use_cuda_graph: + logger.warning( + "Skipping CUDA graph capture. Please add " + "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) + return + + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. 
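+ # Each shape is run `cudagraph_num_of_warmups` times as a warm-up before the final run that gets captured.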
+ with graph_capture(device=self.device): + for num_tokens in reversed(self.cudagraph_batch_sizes): + for _ in range(self.vllm_config.compilation_config. + cudagraph_num_of_warmups): + self._dummy_run(num_tokens) + self._dummy_run(num_tokens) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + if len(kv_cache_config.kv_cache_groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") + + kv_caches: dict[str, torch.Tensor] = {} + kv_caches_scale: dict[str, torch.Tensor] = {} + + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % (kv_cache_spec.page_size_bytes + kv_cache_spec.scale_page_size_bytes) == 0 + num_blocks = tensor_config.size // (kv_cache_spec.page_size_bytes + kv_cache_spec.scale_page_size_bytes) + # `num_blocks` is the number of blocks the model runner can use. + # `kv_cache_config.num_blocks` is the number of blocks that + # KVCacheManager may allocate. + # Since different GPUs may have different number of layers and + # different memory capacities, `num_blocks` can be different on + # different GPUs, and `kv_cache_config.num_blocks` is set to + # the min of all `num_blocks`. Verify it here. + assert num_blocks >= kv_cache_config.num_blocks + if isinstance(kv_cache_spec, AttentionSpec): + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + kv_caches[layer_name] = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + if envs.VLLM_USE_INT8_MLA: + kv_caches_scale_shape = kv_cache_shape[:-1]+(2,) + kv_caches_scale[layer_name] = torch.zeros(kv_caches_scale_shape, + dtype=torch.float32, + device=self.device) + + else: + # TODO: add new branches when introducing more types of + # KV cache specs. + raise ValueError("Unknown KV cache spec type.") + + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + if envs.VLLM_USE_INT8_MLA: + bind_kv_cache_scale( + kv_caches_scale, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches_scale) + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. 
+ """ + + forward_ctx = self.vllm_config.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + kv_cache_spec: dict[str, KVCacheSpec] = {} + for layer_name, attn_module in forward_ctx.items(): + if isinstance(attn_module, FusedMoE): + continue + + # TODO: Support other attention modules, e.g., sliding window, + # cross-attention + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + use_mla=use_mla) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py new file mode 100644 index 0000000..2972e0f --- /dev/null +++ b/vllm/v1/worker/gpu_worker.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A GPU worker class.""" +import gc +import os +from typing import TYPE_CHECKING, Optional + +import torch +import torch.distributed +import torch.nn as nn + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.device_allocator.cumem import CuMemAllocator +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.utils import GiB_bytes +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.worker_base import WorkerBase + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class Worker(WorkerBase): + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + + super().__init__(vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker) + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. 
Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def sleep(self, level: int = 1) -> None: + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags) + + def init_device(self): + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Unsupported device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + # Construct the model runner + self.model_runner: GPUModelRunner = GPUModelRunner( + self.vllm_config, self.device) + + # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool + # to hijack tensor allocation. + def load_model(self) -> None: + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(tag="weights") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.load_model() + + @torch.inference_mode() + def determine_available_memory(self) -> int: + """Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. + + The engine will first profile the existing memory usage. + Then, it calculates the free memory that can be used for KV cache in + bytes. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter.
+ """ + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + _, total_gpu_memory = torch.cuda.mem_get_info() + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + free_gpu_memory, _ = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + assert self.init_gpu_memory > free_gpu_memory, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + # Get the peak memory allocation recorded by torch + peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + + # Check for any memory left around that may have been allocated on the + # gpu outside of `torch`. NCCL operations, for example, can use a few + # GB during a forward pass + torch.cuda.empty_cache() + torch_allocated_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + non_torch_allocations = total_allocated_bytes - torch_allocated_bytes + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + return int(available_kv_cache_memory) + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + return self.model_runner.get_kv_cache_spec() + + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.initialize_kv_cache(kv_cache_config) + + def compile_or_warm_up_model(self) -> None: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size) + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Warm up sampler and preallocate memory buffer for logits and other + # sampling related tensors of max possible shape to avoid memory + # fragmentation issue. + # NOTE: This is called after `capture_model` on purpose to prevent + # memory buffers from being cleared by `torch.cuda.empty_cache`. + if get_pp_group().is_last_rank: + max_num_reqs = min(self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens) + self.model_runner._dummy_sampler_run( + hidden_states=self.model_runner._dummy_run( + num_tokens=max_num_reqs)) + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. 
+ set_random_seed(self.model_config.seed) + + def get_model(self) -> nn.Module: + return self.model_runner.get_model() + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + output = self.model_runner.execute_model(scheduler_output) + return output if self.is_driver_worker else None + + def profile(self, is_start: bool = True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() + + def execute_dummy_batch(self) -> None: + self.model_runner._dummy_run(1) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> set[int]: + return self.model_runner.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. + return + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model_runner.model, + path, + pattern=pattern, + max_size=max_size, + ) + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) + + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py new file mode 100644 index 0000000..a8a19e0 --- /dev/null +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Define LoRA functionality mixin for model runners. 
+""" + +from contextlib import contextmanager + +import numpy as np +import torch.nn as nn + +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.v1.worker.gpu_input_batch import InputBatch + +logger = init_logger(__name__) + + +# Defined as a mixin for GPUModelRunner +class LoRAModelRunnerMixin: + + LORA_WARMUP_RANK = 8 + + def load_lora_model(self, model: nn.Module, model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: LoRAConfig, device: str) -> nn.Module: + + assert supports_lora( + model), f"{model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(model.config, "max_position_embeddings"): + max_pos_embeddings = model.config.max_position_embeddings + else: + max_pos_embeddings = ( + model.config.text_config.max_position_embeddings) + + # Add LoRA Manager to the Model Runner + self.lora_manager = LRUCacheWorkerLoRAManager( + scheduler_config.max_num_seqs, + scheduler_config.max_num_batched_tokens, + model_config.get_vocab_size(), + lora_config, + device, + model.embedding_modules, + model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + return self.lora_manager.create_lora_manager(model) + + def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], + token_lora_mapping: tuple[int, ...], + lora_requests: set[LoRARequest]) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + + # Set is_prefill to True, so we always use the SGMV kernels on + # non-cuda platforms. + # On cuda platforms we use the same kernels for prefill and + # decode and this flag is generally ignored. + lora_mapping = LoRAMapping(token_lora_mapping, + prompt_lora_mapping, + is_prefill=True) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def set_active_loras(self, input_batch: InputBatch, + num_scheduled_tokens: np.ndarray) -> None: + + prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs + token_lora_mapping: tuple[int, + ...] # of size np.sum(num_scheduled_tokens) + lora_requests: set[LoRARequest] + prompt_lora_mapping, token_lora_mapping, lora_requests = \ + input_batch.make_lora_inputs(num_scheduled_tokens) + return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + + @contextmanager + def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_reqs = len(num_scheduled_tokens) + num_loras = lora_config.max_loras + + # Make prompt lora mapping + # Assign LoRA IDs cyclically to simulate a worst-case scenario. 
+ prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % + num_loras) + 1 + + # Make token lora mapping + token_lora_mapping = np.repeat(prompt_lora_mapping, + num_scheduled_tokens) + + # Make dummy lora requests + lora_requests: set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. + for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), + lora_requests) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters() + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 0000000..0668e71 --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,1037 @@ +# SPDX-License-Identifier: Apache-2.0 +import bisect +import time +from typing import TYPE_CHECKING, Optional, cast +from unittest.mock import patch + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +# TPU XLA related +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.sampling_params import SamplingType +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasMetadata) +from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec, SlidingWindowSpec) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, + ModelRunnerOutput, SamplerOutput) +from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata +from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler +from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +from .utils import sanity_check_mm_encoder_outputs + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. 
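+# Padding slots are mapped to this deliberately out-of-range slot index so that their KV-cache writes land outside every real block and are ignored.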
+_PAD_SLOT_ID = 1_000_000_000 +INVALID_TOKEN_ID = -1 +# Smallest output size +MIN_NUM_SEQS = 8 + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.device_config = vllm_config.device_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION + + self.enforce_eager = model_config.enforce_eager + + self.num_xla_graphs = 0 + self._update_num_xla_graphs("init") + + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + self._hidden_states_dtype = self.dtype + + self.is_multimodal_model = model_config.is_multimodal_model + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + # InputBatch needs to work with sampling tensors greater than padding + # to avoid dynamic shapes. Also, avoid suboptimal alignment. + self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) + + # Model-related. + self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope + # TODO: Support M-RoPE (e.g, Qwen2-VL) + assert not self.uses_mrope, "TPU does not support M-RoPE yet." + + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + mm_registry=self.mm_registry, + ) + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: list[torch.Tensor] = [] + # req_id -> (input_id -> encoder_output) + self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + + # Request states. + self.requests: dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), + ) + + # Cached torch/numpy tensor + # The pytorch tensor and numpy array share the same buffer. + # Sometimes the numpy op is faster so we create both. 
+ self.input_ids_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.input_ids_np = self.input_ids_cpu.numpy() + + self.positions_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.positions_np = self.positions_cpu.numpy() + + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu") + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.block_table_cpu = torch.zeros( + (self.max_num_tokens, self.max_num_blocks_per_req), + dtype=self.input_batch.block_table.get_cpu_tensor().dtype, + device="cpu") + + self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + + self.seq_lens_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_lens_np = self.seq_lens_cpu.numpy() + + # Range tensor with values [0 .. self.max_num_tokens - 1]. + # Used to initialize positions / context_lens / seq_lens + self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) + self.num_tokens_paddings = _get_paddings( + min_token_size=16, + max_token_size=self.max_num_tokens, + padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + + def _update_num_xla_graphs(self, case_str): + check_comp = self.check_recompilation and not self.enforce_eager + if not check_comp: + return + + total_cached_graphs = xr.get_num_cached_compilation_graph() + new_compiled_graphs = total_cached_graphs - self.num_xla_graphs + if new_compiled_graphs == 0: + return + + logger.info("Add new %d compiled XLA graphs due to %s", + new_compiled_graphs, case_str) + self.num_xla_graphs += new_compiled_graphs + + def _verify_num_xla_graphs(self, case_str): + check_comp = self.check_recompilation and not self.enforce_eager + if not check_comp: + return + + curr_cached_graph = xr.get_num_cached_compilation_graph() + assert self.num_xla_graphs == curr_cached_graph, ( + "Recompilation after warm up is detected during {}." + " num_xla_graphs = {} curr_cached_graph = {}".format( + case_str, self.num_xla_graphs, curr_cached_graph)) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + Returns: + True if there is a new/resumed/paused/finished request. + If False, we can skip copying SamplingMetadata to the GPU. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: list[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Free the cached encoder outputs. 
+ for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + req_index = self.input_batch.remove_request(req_id) + assert req_index is not None + removed_req_indices.append(req_index) + + req_ids_to_add: list[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. + for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + # Update the cached states. + req_state.num_computed_tokens = req_data.num_computed_tokens + if not req_data.resumed_from_preemption: + # Append the new blocks to the existing block IDs. + req_state.block_ids.extend(req_data.new_block_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = req_data.new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. 
+ req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 + + def get_model(self) -> nn.Module: + assert self.model is not None + return self.model + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + + forward_ctx = self.vllm_config.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size + kv_cache_spec: dict[str, KVCacheSpec] = {} + for layer_name, attn_module in forward_ctx.items(): + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + sliding_window=attn_module.sliding_window, + use_mla=False, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + use_mla=False, + ) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # Get the number of scheduled tokens for each request. + num_scheduled_tokens_per_req = [] + max_num_scheduled_tokens_all_reqs = 0 + for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens_per_req.append(num_tokens) + max_num_scheduled_tokens_all_reqs = max( + max_num_scheduled_tokens_all_reqs, num_tokens) + num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, + dtype=np.int32) + assert max_num_scheduled_tokens_all_reqs > 0 + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + # For each scheduled token, what are the corresponding req index. + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens_per_req) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # For each scheduled token, what is its position in corresponding req. + arange = np.concatenate( + [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) + + # Get positions. + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. 
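+ # token_ids_cpu is a (max_num_reqs, max_model_len) buffer, so request r's position p flattens to r * M + p when the tensor is flattened below.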
+ token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens_per_req, + out=self.query_start_loc_np[1:num_reqs + 1]) + self.query_start_loc_np[num_reqs + 1:] = 1 + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req) + + # Do the padding and copy the tensors to the TPU. + padded_total_num_scheduled_tokens = _get_padded_token_len( + self.num_tokens_paddings, total_num_scheduled_tokens) + # Zero out to avoid spurious values from prev iteration (last cp chunk) + self.input_ids_cpu[ + total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0 + self.input_ids = self.input_ids_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.position_ids = self.positions_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID + slot_mapping = self.slot_mapping_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + block_tables = self.block_table_cpu[:self.max_num_reqs] + block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( + self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) + block_tables = block_tables.to(self.device) + query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( + self.device) + seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) + + attn_metadata = PallasMetadata( + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=seq_lens, + query_start_loc=query_start_loc, + num_seqs=torch.tensor([num_reqs], + dtype=torch.int32, + device=self.device), + ) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + padded_num_reqs = _get_padded_num_reqs_with_upper_limit( + num_reqs, self.max_num_reqs) + # Indices at which we sample (positions of last token in the sequence). + # Padded to avoid recompiling when `num_reqs` varies. 
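+ # For real requests, query_start_loc[i + 1] - 1 is the last scheduled token of request i; the padded tail entries (set to 1 above) resolve to index 0 and are simply ignored.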
+ logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 + logits_indices = logits_indices.to(self.device) + return attn_metadata, logits_indices + + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs: list[MultiModalKwargs] = [] + req_input_ids: list[tuple[str, int]] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + for input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[input_id]) + req_input_ids.append((req_id, input_id)) + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, + # we process it separately to preserve item order. + # FIXME(ywang96): This is a hacky way to deal with multiple modalities + # in the same batch while still being able to benefit from batching + # multimodal inputs. The proper solution should be reordering the + # encoder outputs. + grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) + + encoder_outputs = [] + for grouped_mm_inputs in grouped_mm_inputs_list: + batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, each of shape + # (feature_size, hidden_size) in case the feature size is dynamic + # depending on the input multimodal items. + curr_group_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) + + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=len(grouped_mm_inputs), + ) + + for output in curr_group_outputs: + encoder_outputs.append(output) + + # Cache the encoder outputs. + for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + self.encoder_cache[req_id][input_id] = output + + def _gather_encoder_outputs( + self, + scheduler_output: "SchedulerOutput", + ) -> list[torch.Tensor]: + encoder_outputs: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. 
+ continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] + encoder_outputs.append(encoder_output[start_idx:end_idx]) + return encoder_outputs + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> ModelRunnerOutput: + # Update cached state + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + # Return empty ModelRunnerOuptut if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_encoder(scheduler_output) + encoder_outputs = self._gather_encoder_outputs(scheduler_output) + else: + encoder_outputs = [] + + # Prepare inputs + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + if encoder_outputs: + inputs_embeds = self.model.get_input_embeddings( + self.input_ids, encoder_outputs) + else: + inputs_embeds = self.model.get_input_embeddings(self.input_ids) + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = self.input_ids + inputs_embeds = None + num_reqs = self.input_batch.num_reqs + # NOTE (NickLucche) here we sync with TPU: sampling params tensors + # are copied to device in chunks of pre-compiled padded shape to + # avoid recompilations. + tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ + from_input_batch(self.input_batch, logits_indices) + # Run the decoder + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + positions=self.position_ids, + kv_caches=self.kv_caches, + inputs_embeds=inputs_embeds, + ) + selected_token_ids = self.model.sample_from_hidden( + hidden_states, tpu_sampling_metadata) + # Remove padding on cpu and keep dynamic op outside of xla graph. + selected_token_ids = selected_token_ids.cpu()[:num_reqs] + + # Update the cache state concurrently. Code above will not block until + # we use `selected_token_ids`. Add mark_step if post-processing changes + request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] + discard_sampled_tokens_req_indices = [] + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): + assert req_id is not None + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len >= req_state.num_tokens: + request_seq_lens.append((i, req_state, seq_len)) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. 
+ generator = self.input_batch.generators.get(i) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) + + # Record the index of the request that should not be sampled, + # so that we could clear the sampled tokens before returning. + discard_sampled_tokens_req_indices.append(i) + + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) + + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + for req_id in self.input_batch.req_ids[:num_reqs]: + prompt_logprobs_dict[req_id] = None + + max_gen_len = selected_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = selected_token_ids.tolist() + + # Mask out the sampled tokens that should not be sampled. + # TODO: Keep in sync with gpu_model_runner.py, in particular + # the "else" case here + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + # Append sampled tokens + for i, req_state, seq_len in request_seq_lens: + token_id = valid_sampled_token_ids[i][0] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + self.input_batch.num_tokens[i] += 1 + + else: + valid_mask = selected_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + valid_sampled_token_ids = [ + seq.tolist() + for seq in selected_token_ids[valid_mask].split(gen_lens) + ] + self.input_batch.num_tokens[:num_reqs] += gen_lens + for i, req_state, seq_len in request_seq_lens: + target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) + self.input_batch.token_ids_cpu[ + i, target_slice] = valid_sampled_token_ids[i] + req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict=prompt_logprobs_dict, + ) + + # Check there are no new graphs compiled - all the graphs should be + # captured and compiled during warm up. + self._verify_num_xla_graphs("execute_model") + + return model_runner_output + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.mark_step() + xm.wait_device_ops() + model = ModelWrapperV1(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + + @torch.no_grad() + def _dummy_run(self, kv_caches, num_tokens: int) -> None: + if self.is_multimodal_model: + input_ids = None + inputs_embeds = torch.zeros((num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) + else: + input_ids = torch.zeros((num_tokens), + dtype=torch.int32, + device=self.device) + inputs_embeds = None + actual_num_reqs = min(num_tokens, self.max_num_reqs) + position_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros(num_tokens, + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (self.max_num_reqs, self.block_table_cpu.shape[1]), + dtype=torch.int32, + device=self.device) + query_lens = [1] * self.max_num_reqs + query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, + dtype=torch.int32), + dim=0, + dtype=torch.int32).to(self.device) + context_lens = torch.ones((self.max_num_reqs, ), + dtype=torch.int32, + device=self.device) + num_seqs = torch.tensor([actual_num_reqs], + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + query_start_loc=query_start_loc, + num_seqs=num_seqs, + ) + + if self.is_multimodal_model: + torch._dynamo.mark_dynamic(inputs_embeds, 0) + else: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + + with set_forward_context(attn_metadata, self.vllm_config, 0): + out = self.model(input_ids=input_ids, + positions=position_ids, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds) + self._hidden_states_dtype = out.dtype + + def capture_model(self) -> None: + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + start = time.perf_counter() + for num_tokens in self.num_tokens_paddings: + logger.info(" -- num_tokens: %d", num_tokens) + self._dummy_run(self.kv_caches, num_tokens) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + + logger.info("Compilation finished in in %.2f [secs].", end - start) + self._update_num_xla_graphs("model") + + logger.info("Compiling sampling with different input shapes.") + start = time.perf_counter() + hsize = self.model_config.get_hidden_size() + device = self.device + # Compile sampling step for different model+sampler outputs in bucketed + # n_tokens x max_num_reqs. Graph is really small so this is fine. + for num_tokens in self.num_tokens_paddings: + num_reqs_to_sample = MIN_NUM_SEQS + dummy_hidden = torch.randn((num_tokens, hsize), + device=device, + dtype=self._hidden_states_dtype) + # Compile for [8, 16, .., 128,.., `self.max_num_reqs`] + while True: + indices = torch.zeros( + num_reqs_to_sample, + dtype=torch.int32, + device=device, + ) + xm.mark_step() + sampling_meta = TPUSupportedSamplingMetadata.\ + from_input_batch(self.input_batch, indices) + logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, + num_reqs_to_sample) + out = self.model.sample_from_hidden(dummy_hidden, + sampling_meta) + out = out.cpu() + # Requests can't be more than tokens. But do compile for the + # next bigger value in case num_tokens uses bucketed padding. 
+ if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs): + break + # Make sure to compile the `max_num_reqs` upper-limit case + num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit( + num_reqs_to_sample + 1, self.max_num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + + logger.info("Compilation finished in in %.2f [secs].", end - start) + self._update_num_xla_graphs("sampling") + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + if len(kv_cache_config.kv_cache_groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") + + kv_caches: dict[str, torch.Tensor] = {} + + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, FullAttentionSpec): + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + tpu_kv_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + + kv_caches[layer_name] = tpu_kv_cache + else: + raise NotImplementedError + + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + + +class ModelWrapperV1(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + self.sampler = TPUSampler() + + def sample( + self, logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput: + sampler_out = self.sampler(logits, sampling_metadata) + return sampler_out + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: list[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Executes the forward pass of the model. + + Args: + input_ids: The input token IDs of shape [num_tokens]. + positions: The input position IDs of shape [num_tokens]. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + inputs_embeds: The input embeddings of shape [num_tokens, + hidden_size]. It is used for multimodal models. + """ + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def sample_from_hidden( + self, + hidden_states: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata, + ) -> torch.Tensor: + """ + Sample with xla-friendly function. This function is to be traced + separately from `forward` for lighter compilation overhead. + """ + # Tensor `sample_hidden_states` is of fixed pre-compiled size. + sample_hidden_states = \ + hidden_states[sampling_metadata.indices_do_sample] + logits = self.compute_logits(sample_hidden_states) + # Optimized greedy sampling branch, tracing both paths in a single pass + # NOTE all_greedy is a scalar, this is just an optimized if/else. 
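+        # torch.where selects the argmax token ids when every request is
+        # greedy and the sampler output otherwise, keeping the graph static.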
+ out_tokens = torch.where(sampling_metadata.all_greedy, + torch.argmax(logits, dim=-1, keepdim=True), + self.sample(logits, sampling_metadata)\ + .sampled_token_ids) + return out_tokens + + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + # SamplingMetadata here for pruning output in LogitsProcessor, disabled + logits = self.model.compute_logits(hidden_states, None) + return logits + + def get_multimodal_embeddings(self, *args, **kwargs): + return self.model.get_multimodal_embeddings(*args, **kwargs) + + def get_input_embeddings(self, *args, **kwargs): + return self.model.get_input_embeddings(*args, **kwargs) + + +def _get_padded_number(n: int, multiple: int) -> int: + return ((n + multiple - 1) // multiple) * multiple + + +def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int: + res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length() + return min(res, upper_limit) + + +def _get_paddings(min_token_size: int, max_token_size: int, + padding_gap: int) -> list[int]: + """Generate a list of padding size, starting from min_token_size, + ending with a number that can cover max_token_size + + If padding_gap == 0 then: + increase 2X each time (exponential) + else: + first increase the size to twice, + then increase the padding size by padding_gap. + """ + paddings = [] + num = min_token_size + + if padding_gap == 0: + logger.info("Using exponential paddings:") + while num <= max_token_size: + logger.info(" %d", num) + paddings.append(num) + num *= 2 + + else: + logger.info("Using incremental paddings:") + while num <= padding_gap: + logger.info(" %d", num) + paddings.append(num) + num *= 2 + num //= 2 + while num < max_token_size: + num += padding_gap + logger.info(" %d", num) + paddings.append(num) + + return paddings + + +def _get_padded_token_len(paddings: list[int], x: int) -> int: + """Return the first element in paddings list greater or equal to x. 
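+    E.g. with paddings [16, 32, 64] and x = 20 this returns 32. `paddings`
+    must be sorted and its last element must cover x (asserted below).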
+ """ + index = bisect.bisect_left(paddings, x) + assert index < len(paddings) + return paddings[index] diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py new file mode 100644 index 0000000..67902b4 --- /dev/null +++ b/vllm/v1/worker/tpu_worker.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A TPU worker class.""" +import os +from typing import Optional + +import torch +import torch.distributed +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.debug.profiler as xp +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, + KVCacheSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.tpu_model_runner import TPUModelRunner + +logger = init_logger(__name__) + + +class TPUWorker: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + self.is_driver_worker = is_driver_worker + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Delay profiler initialization to the start of the profiling. + # This is because in vLLM V1, MP runtime is initialized before the + # TPU Worker is initialized. The profiler server needs to start after + # MP runtime is initialized. + self.profiler = None + self.profile_dir = None + if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: + # For TPU, we can only have 1 active profiler session for 1 profiler + # server. So we only profile on rank0. + self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + self.profile_dir) + + if self.model_config.seed is None: + self.model_config.seed = 0 + + def init_device(self): + os.environ["PJRT_DEVICE"] = "TPU" + # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D + # ring, the xla tpu compiler flag + # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to + # fix this. It will be removed after the bug in XLA compiler is fixed. 
+ os.environ["LIBTPU_INIT_ARGS"] = ( + "--xla_tpu_force_1d_allreduce_at_chunk_count=1") + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # Initialize the distributed environment. + init_tpu_worker_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + self.local_rank) + + # Device initialization should happen after initializing + # the distributed runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + if self.model_config.seed is not None: + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # TODO (NickLucche) On gsm we compile 80+ graphs. + # Re-evaluate limit, with MM we may get close to this limit. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. + # Consequently, changes in optimization flags, which affect compilation + # results, don't change the cache key. This can result in the wrong + # compilation being used. To prevent this, disabling the XLA compilation + # cache during development is recommended.We can disable it by + # `export VLLM_XLA_CACHE_PATH=` + if envs.VLLM_XLA_CACHE_PATH: + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config, self.device) + + def determine_available_memory(self) -> int: + kv_caches: dict[str, torch.Tensor] = {} + kv_cache_spec = self.model_runner.get_kv_cache_spec() + for layer_name, layer_spec in kv_cache_spec.items(): + if isinstance(layer_spec, AttentionSpec): + dtype = layer_spec.dtype + + # Use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + tpu_kv_cache = torch.tensor([], + dtype=dtype, + device=self.device) + kv_caches[layer_name] = tpu_kv_cache + else: + raise NotImplementedError( + f"Unsupported KV cache spec '{type(layer_spec)}'") + + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + runner_kv_caches) + + self.model_runner._dummy_run( + runner_kv_caches, + num_tokens=self.scheduler_config.max_num_batched_tokens, + ) + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + current_mem = m["bytes_used"] + # Ideally we would use profiled = m["peak_bytes_used"] to + # get weights + activations. But there is memory used during + # compilation / weight loading that impacts the peak and + # there is no way to reset peak memory in XLA, So we + # use the heuristic of 2% of weights. + profiled = current_mem * 1.02 + + # Calculate the TPU KV cache size based on profiling. 
+ usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + + return int(tpu_kv_cache_bytes) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + output = self.model_runner.execute_model(scheduler_output) + return output if self.is_driver_worker else None + + def profile(self, is_start: bool = True): + if self.rank < 1: + if self.profile_dir is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + if self.profiler is None: + self.profiler = xp.start_server(9012) + xp.start_trace(self.profile_dir) + else: + xp.stop_trace() + + def load_model(self) -> None: + self.model_runner.load_model() + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def get_model(self) -> nn.Module: + return self.model_runner.get_model() + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + return self.model_runner.get_kv_cache_spec() + + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + self.model_runner.initialize_kv_cache(kv_cache_config) + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. + return + + +def init_tpu_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + local_rank=local_rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py new file mode 100644 index 0000000..b1d3aa7 --- /dev/null +++ b/vllm/v1/worker/utils.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sanity_check_mm_encoder_outputs( + mm_embeddings: object, + expected_num_items: int, +) -> None: + """ + Perform sanity checks for the result of + :meth:`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`. + """ + assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), ( + "Expected multimodal embeddings to be a list/tuple of 2D tensors, " + f"or a single 3D tensor, but got {type(mm_embeddings)} " + "instead. This is most likely due to incorrect implementation " + "of the model's `get_multimodal_embeddings` method.") + + assert len(mm_embeddings) == expected_num_items, ( + "Expected number of multimodal embeddings to match number of " + f"input items: {expected_num_items}, but got {len(mm_embeddings)=} " + "instead. 
This is most likely due to incorrect implementation " + "of the model's `get_multimodal_embeddings` method.") + + assert all(e.ndim == 2 for e in mm_embeddings), ( + "Expected multimodal embeddings to be a sequence of 2D tensors, " + f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " + "instead. This is most likely due to incorrect implementation " + "of the model's `get_multimodal_embeddings` method.") diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py new file mode 100644 index 0000000..487a49b --- /dev/null +++ b/vllm/v1/worker/worker_base.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 + +logger = init_logger(__name__) + + +class WorkerBase(WorkerBaseV0): + """ + Abstract class for v1 worker, mainly define some methods for v1. + For methods shared by v0 and v1, define them in v0 WorkerBase + """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + """ + Initialize common worker components. + + Args: + vllm_config: Complete vLLM configuration + local_rank: Local device index + rank: Global rank in distributed setup + distributed_init_method: Distributed initialization method + is_driver_worker: Whether this worker handles driver + responsibilities + """ + # Configuration storage + super().__init__(vllm_config=vllm_config) + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + + # Device and model state + self.device: Optional[torch.device] = None + self.model_runner: Optional[nn.Module] = None + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """Get specifications for KV cache implementation.""" + raise NotImplementedError + + def compile_or_warm_up_model(self) -> None: + """Prepare model for execution through compilation/warmup.""" + raise NotImplementedError + + def check_health(self) -> None: + """Basic health check (override for device-specific checks).""" + return diff --git a/vllm/version.py b/vllm/version.py new file mode 100644 index 0000000..c68dea1 --- /dev/null +++ b/vllm/version.py @@ -0,0 +1,2 @@ +__version__ = "0.8.3" +__version_tuple__ = (0, 8, 3) diff --git a/vllm/vllm_flash_attn/.gitkeep b/vllm/vllm_flash_attn/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/vllm/vllm_flash_attn/fa_utils.py b/vllm/vllm_flash_attn/fa_utils.py new file mode 100644 index 0000000..ca88549 --- /dev/null +++ b/vllm/vllm_flash_attn/fa_utils.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +from vllm import envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: + # import here to avoid circular dependencies + from vllm.platforms import current_platform + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) + device_capability = current_platform.get_device_capability() + + assert device_capability is not None + + # 1. 
default version depending on platform + fa_version = 3 if (device_capability.major == 9 + and is_fa_version_supported(3)) else 2 + + # 2. override if passed by environment + if envs.VLLM_FLASH_ATTN_VERSION is not None: + assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] + fa_version = envs.VLLM_FLASH_ATTN_VERSION + + # 3. fallback for unsupported combinations + if device_capability.major == 10 and fa_version == 3: + logger.warning_once( + "Cannot use FA version 3 on Blackwell platform " + "defaulting to FA version 2.") + fa_version = 2 + + if requires_alibi and fa_version == 3: + logger.warning_once("Cannot use FA version 3 with ALiBi, " + "defaulting to FA version 2.") + fa_version = 2 + + if not is_fa_version_supported(fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + fa_version, fa_version_unsupported_reason(fa_version)) + + assert is_fa_version_supported(fa_version) + return fa_version + except (ImportError, AssertionError): + return None + + +def flash_attn_supports_fp8() -> bool: + from vllm.platforms import current_platform + return get_flash_attn_version() == 3 and \ + current_platform.get_device_capability().major == 9 diff --git a/vllm/worker/__init__.py b/vllm/worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py new file mode 100644 index 0000000..e2fe281 --- /dev/null +++ b/vllm/worker/cache_engine.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +"""CacheEngine class for managing the KV cache.""" +from typing import List + +import torch + +from vllm.attention import get_attn_backend +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig +from vllm.logger import init_logger +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, + get_dtype_size, is_pin_memory_available) +import vllm.envs as envs + +logger = init_logger(__name__) + + +class CacheEngine: + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU and CPU KV + caches. It also provides methods for performing KV cache operations, such + as swapping and copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + ) -> None: + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + self.device_config = device_config + + self.head_size = model_config.get_head_size() + # Models like Jamba, have mixed typed layers, E.g Mamba + self.num_attention_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + self.num_gpu_blocks = cache_config.num_gpu_blocks + if self.num_gpu_blocks: + self.num_gpu_blocks //= parallel_config.pipeline_parallel_size + self.num_cpu_blocks = cache_config.num_cpu_blocks + if self.num_cpu_blocks: + self.num_cpu_blocks //= parallel_config.pipeline_parallel_size + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + if envs.VLLM_USE_INT8_MLA: + self.dtype = torch.int8 + # Get attention backend. 
+ self.attn_backend = get_attn_backend(self.head_size, + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + model_config.is_attention_free, + use_mla=model_config.use_mla) + + # Initialize the cache. + if envs.VLLM_USE_INT8_MLA: + self.gpu_cache, self.gpu_cache_scale = self._allocate_kv_cache( + self.num_gpu_blocks, self.device_config.device_type) + else: + self.gpu_cache = self._allocate_kv_cache( + self.num_gpu_blocks, self.device_config.device_type) + self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") + + def _allocate_kv_cache( + self, + num_blocks: int, + device: str, + ) -> List[torch.Tensor]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + pin_memory = is_pin_memory_available() if device == "cpu" else False + kv_cache: List[torch.Tensor] = [] + kv_cache_scale: List[torch.Tensor] = [] + + for _ in range(self.num_attention_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. + layer_kv_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device) + + # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases + # when entry_shape is higher than 1D + kv_cache.append(layer_kv_cache) + if envs.VLLM_USE_INT8_MLA: + kv_cache_scale_shape = kv_cache_shape[:-1]+(2,) + layer_kv_cache_scale = torch.zeros(kv_cache_scale_shape, + dtype=torch.float32, + pin_memory=pin_memory, + device=device) + kv_cache_scale.append(layer_kv_cache_scale) + if envs.VLLM_USE_INT8_MLA: + return kv_cache, kv_cache_scale + else: + return kv_cache + + def swap_in(self, src_to_dst: torch.Tensor) -> None: + for i in range(self.num_attention_layers): + self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], + src_to_dst) + + def swap_out(self, src_to_dst: torch.Tensor) -> None: + for i in range(self.num_attention_layers): + self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], + src_to_dst) + + def copy(self, src_to_dsts: torch.Tensor) -> None: + self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + key_cache_entry = num_heads * head_size + + # For MLA there is no value cache, since the latent vector + # is joint keys and values. + value_cache_entry = key_cache_entry if not model_config.use_mla else 0 + total = num_attention_layers * cache_config.block_size * \ + (key_cache_entry + value_cache_entry) + + if envs.VLLM_USE_INT8_MLA: + dtype = torch.int8 + dtype_quant_scale = get_dtype_size(torch.float32) + ## 因为要保存两个scale. 
kv_scale 和 k_pe_scale + total_qaunt_scale = num_heads * dtype_quant_scale * 2 + return get_dtype_size(dtype) * total + total_qaunt_scale + else: + dtype_size = get_dtype_size(dtype) + return dtype_size * total diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py new file mode 100644 index 0000000..ac7c93e --- /dev/null +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast + +import torch + +from vllm.attention import AttentionMetadata +from vllm.forward_context import set_forward_context +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, + ModelInputForCPUBuilder, + ModelInputForCPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata): + """ + Used by the EncoderDecoderModelRunner. + """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInputForCPU": + return cast( + EncoderDecoderModelInputForCPU, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +class CPUEncoderDecoderModelRunner( + CPUModelRunnerBase[EncoderDecoderModelInputForCPU]): + _model_input_cls: Type[EncoderDecoderModelInputForCPU] = ( + EncoderDecoderModelInputForCPU) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def _list_to_int32_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.int32, device=self.device) + + def _list_to_long_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.long, device=self.device) + + def _empty_int32_tensor(self) -> torch.Tensor: + return self._list_to_int32_tensor([]) + + def _empty_long_tensor(self) -> torch.Tensor: + return self._list_to_long_tensor([]) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, + Any]) -> EncoderDecoderModelInputForCPU: + return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) 
-> EncoderDecoderModelInputForCPU: + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + return dataclasses.replace( + model_input, + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + virtual_engine=virtual_engine, + ) + + def _prepare_encoder_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInputForCPU, + ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Helper method to prepare the encoder- and cross-attn-related + model inputs based on a given sequence group. These additional inputs + are used to augment an already-computed `EncoderDecoderModelInput` + data structure which already has decoder-related model inputs + populated. + + Sets the following attn_metadata fields: + * `num_encoder_tokens` + * `encoder_seq_lens` + * `encoder_seq_lens_tensor` + * `max_encoder_seq_len` + * `cross_slot_mapping` + * `cross_block_tables` + + Constructs a new model inputs data structure, based on + (1) the existing fields in the `model_inputs` argument, + and (2) the following additional fields which are + computed (or in the case of `attn_metadata`, updated) + by this function: + * attn_metadata + * encoder_input_tokens + * encoder_input_positions + + Arguments: + + * seq_group_metadata_list: list of sequence groups for which to + compute inputs + * model_inputs: model inputs data structure with decoder-oriented + fields already computed. + + Return: + + * Updated model inputs data structure + """ + + if len(seq_group_metadata_list) == 0: + return (model_input.attn_metadata, None, None) + + # Since we are not supporting chunked prefill either the entire + # batch is prefill or it is decode + is_prompt = seq_group_metadata_list[0].is_prompt + + # Build encoder inputs + encoder_seq_lens: List[int] = [] + if is_prompt: + # Prefill phase. 
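+            # An empty placeholder stands in for the cross-attention block
+            # tables; only the cross slot mapping is built during prefill.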
+ cross_block_tables = self._empty_int32_tensor().view( + len(seq_group_metadata_list), -1) + + # Extract input tokens/positions, cross-attention slot-mapping, + # & seq len from each sequence group metadata + ( + encoder_input_tokens, + encoder_input_positions, + cross_slot_mapping, + ) = ( + [], + [], + [], + ) + for seq_group_metadata in seq_group_metadata_list: + # Build seq lens + seq_len = seq_group_metadata.encoder_seq_data.get_len() + token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() + encoder_seq_lens.append(seq_len) + + # Build slot mapping + for i in range(0, seq_len): + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + cross_slot_mapping.append(slot) + + # Build encoder input tokens + encoder_input_tokens.extend(token_ids) + encoder_input_positions.extend(list(range(0, seq_len))) + + # Convert tokens/positions & cross-attention + # slot-mapping to encoder input tensors + encoder_input_tokens_tensor = self._list_to_long_tensor( + encoder_input_tokens) + encoder_input_positions_tensor = self._list_to_long_tensor( + encoder_input_positions) + cross_slot_mapping_tensor = self._list_to_long_tensor( + cross_slot_mapping) + + else: + # Decode phase. + encoder_input_tokens_tensor = self._empty_long_tensor() + encoder_input_positions_tensor = self._empty_long_tensor() + cross_slot_mapping_tensor = self._empty_long_tensor() + # Extract cross-attention block tables & + # seq len from each sequence group metadata. + # Cross-attention block tables are empty + # during vLLM memory profiling. + cross_block_tables = [] + for seq_group_metadata in seq_group_metadata_list: + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) + + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + + cross_block_tables = make_tensor_with_pad( + cross_block_tables, + max_len=max_len_of_block_table, + pad=0, + dtype=torch.int32, + device=self.device, + ) + + # Compute encoder sequence lengths & encoder + # sequence starting offset tensors + max_encoder_seq_len = max(encoder_seq_lens, default=0) + encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + + # Update attention metadata with encoder-oriented attributes + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + ( + attn_metadata.num_encoder_tokens, + attn_metadata.encoder_seq_lens, + attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_slot_mapping, + attn_metadata.cross_block_tables, + ) = ( + sum(encoder_seq_lens), + encoder_seq_lens, + encoder_seq_lens_tensor, + max_encoder_seq_len, + cross_slot_mapping_tensor, + cross_block_tables, + ) + + return (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor) + + @torch.no_grad() + def execute_model( + self, + model_input: EncoderDecoderModelInputForCPU, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> 
Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + model_executable = self.model + execute_model_kwargs = { + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "encoder_input_ids": + model_input.encoder_input_tokens, + "encoder_positions": + model_input.encoder_input_positions, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + "intermediate_tensors": + intermediate_tensors, + } + + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return [output] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 0000000..9f4b188 --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,684 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import weakref +from collections import defaultdict +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, + TypeVar, Union) + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs, MultiModalPlaceholderMap) +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU") +_PAD_SLOT_ID = -1 + + +@dataclass(frozen=True) +class ModelInputForCPU(ModelRunnerInputBase): + """ + Base class contains metadata needed for the base model forward pass on CPU + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + token_type_ids: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + + def 
as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForCPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> TModelInputForCPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForCPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): + + class ModelInputData: + + def __init__(self, use_mrope: bool): + self.use_mrope = use_mrope + self.input_tokens: List[int] = [] + self.input_positions: List[int] = [] + self.token_type_ids: Optional[List[int]] = [] + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] + self.prefill_block_tables: List[List[int]] = [] + self.decode_block_tables: List[List[int]] = [] + self.max_decode_seq_len: int = 0 + self.num_prefills: int = 0 + self.num_prefill_tokens: int = 0 + self.num_decode_tokens: int = 0 + self.slot_mapping: List[int] = [] + self.multi_modal_inputs_list: List[MultiModalKwargs] = [] + self.multi_modal_placeholder_maps: Dict[ + str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + self.input_mrope_positions: List[List[int]] = [[] + for _ in range(3)] + + def __init__(self, + runner: "CPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.runner = runner + self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled + or runner.cache_config.enable_prefix_caching) + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.enable_lora = self.runner.lora_config is not None + if self.runner.attn_backend is not None: + # spec decode (e.g. 
Medusa) does not have atten backend + attn_backend = self.runner.attn_backend + self.att_metadata_builder = attn_backend.get_builder_cls()(self) + + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.input_data = ModelInputForCPUBuilder.ModelInputData( + self.runner.model_config.uses_mrope) + self.att_metadata_builder.prepare() + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def set_seq_group_list( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): + self.seq_group_metadata_list = seq_group_metadata_list + + def build(self) -> ModelInputForCPU: + self._build_input_data() + + input_data = self.input_data + input_tokens = torch.tensor(input_data.input_tokens, + dtype=torch.long, + device="cpu") + input_positions = torch.tensor( + input_data.input_positions + if not any(input_data.input_mrope_positions) else + input_data.input_mrope_positions, + dtype=torch.long, + device="cpu") + token_type_ids = torch.tensor(input_data.token_type_ids, + dtype=torch.long, + device="cpu") \ + if input_data.token_type_ids else None + + # For multi-modal models + multi_modal_kwargs = None + if len(input_data.multi_modal_inputs_list) != 0: + multi_modal_kwargs = MultiModalKwargs.batch( + input_data.multi_modal_inputs_list) + + attn_metadata = self.att_metadata_builder.build( + input_data.seq_lens, input_data.query_lens, -1, -1) + + is_prompt = (self.seq_group_metadata_list[0].is_prompt + if self.seq_group_metadata_list else None) + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(seq.lora_request + for seq in self.seq_group_metadata_list + if seq.lora_request is not None) + + lora_mapping = self._prepare_lora_input( + self.seq_group_metadata_list, is_prompt) + + return self.model_input_cls(input_tokens=input_tokens, + input_positions=input_positions, + token_type_ids=token_type_ids, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + lora_mapping=lora_mapping, + lora_requests=lora_requests) + + def _build_input_data(self): + for seq_group_metadata in self.seq_group_metadata_list: + for seq_id, seq_data in seq_group_metadata.seq_data.items(): + if seq_group_metadata.is_prompt: + self._compute_prompt_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + if seq_group_metadata.multi_modal_data: + self._compute_multi_modal_input( + seq_group_metadata, seq_data) + else: + self._compute_decode_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + + def _compute_decode_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute decode input tokens, positions, block table and slot mapping. 
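+        Each decode step consumes exactly one new token per sequence
+        (query_len == 1), taken as the last token of `seq_data`.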
+ """ + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + + tokens = seq_data.get_last_token_id() + token_positions = seq_len - 1 + block_number = block_table[token_positions // block_size] + block_offset = token_positions % block_size + slot = block_number * block_size + block_offset + + # For paged_attention kernel + if self.runner.sliding_window: + start_idx = max(0, seq_len - self.runner.sliding_window) + start_block = start_idx // block_size + start_idx = start_block * block_size + seq_len = seq_len - start_idx + block_table = block_table[start_block:] + + # For MRotaryEmbedding + if seq_data.mrope_position_delta is not None: + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + data.input_mrope_positions[idx].extend( # type: ignore + next_pos[idx]) + else: + data.input_positions.append(token_positions) # type: ignore + + # Update fields + data.input_tokens.append(tokens) + data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len) + data.num_decode_tokens += 1 + data.slot_mapping.append(slot) + data.decode_block_tables.append(block_table) + data.query_lens.append(1) + data.seq_lens.append(seq_len) + + def _compute_prompt_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute prompt input tokens, positions, block table and slot mapping. + """ + token_chunk_size = seq_group_metadata.token_chunk_size + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_len, context_len + token_chunk_size) + + # For prefix caching + prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) + if prefix_cache_block_num > 0: + prefix_cache_len = (prefix_cache_block_num * + self.runner.block_size) + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + context_len = prefix_cache_len + token_chunk_size = seq_len - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. + context_len = seq_len - 1 + token_chunk_size = 1 + + tokens = seq_data.get_token_ids() + tokens = tokens[context_len:seq_len] + token_positions = range(context_len, seq_len) + token_types = seq_group_metadata.token_type_ids + + # For encoder-only models, the block_table is None, + # and there is no need to initialize the slot_mapping. 
+ if block_table is not None: + slot_mapping = [_PAD_SLOT_ID] * len(token_positions) + for i, pos in enumerate(token_positions): + block_number = block_table[pos // block_size] + block_offset = pos % block_size + slot = block_number * block_size + block_offset + slot_mapping[i] = slot + data.slot_mapping.extend(slot_mapping) + + # The MROPE positions are prepared in _compute_multi_modal_input + data.input_positions.extend(token_positions) + + if data.token_type_ids is not None: + data.token_type_ids.extend(token_types if token_types else []) + + # Update fields + data.input_tokens.extend(tokens) + data.num_prefills += 1 + data.num_prefill_tokens += len(tokens) + data.query_lens.append(len(tokens)) + data.prefill_block_tables.append(block_table) + data.seq_lens.append(seq_len) + + def _compute_multi_modal_input(self, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData): + computed_len = seq_data.get_num_computed_tokens() + seq_len = self.input_data.seq_lens[-1] + + # NOTE: mm_data only includes the subset of multi-modal items that + # intersect with the current prefill positions. + mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group_metadata, range(computed_len, seq_len)) + + if not mm_data: + return + + if self.runner.mm_registry.has_processor(self.runner.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + # special processing for mrope position deltas. + if self.runner.model_config.uses_mrope: + assert not self.chunked_prefill, \ + "MROPE on CPU does not support chunked-prefill." + + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=computed_len, + ) + seq_data.mrope_position_delta = mrope_position_delta + + for i in range(3): + self.input_data.input_mrope_positions[ # type: ignore + i].extend(mrope_positions[i]) + + self.input_data.multi_modal_inputs_list.append(mm_kwargs) + for modality, placeholder_map in placeholder_maps.items(): + self.input_data.multi_modal_placeholder_maps[modality].extend( + placeholder_map) + + def _prepare_lora_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + is_prefill: bool) -> LoRAMapping: + index_mapping = [] + prompt_mapping = [] + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + query_len = seq.token_chunk_size + + index_mapping += [lora_id] * query_len + prompt_mapping += [lora_id] * ( + query_len if seq.sampling_params + and seq.sampling_params.prompt_logprobs is not None else 1) + + return LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=is_prefill) + + +class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): + """ + Helper class for shared methods between CPU model runners. 
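+    Subclasses set `_model_input_cls` and `_builder_cls` to define how model
+    inputs are represented and built.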
+ """ + _model_input_cls: Type[TModelInputForCPU] + _builder_cls: Type[ModelInputForCPUBuilder] + builder: ModelInputForCPUBuilder + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + *args, + **kwargs, + ): + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config + cache_config = self.cache_config + + self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states + + self.device = self.device_config.device + self.pin_memory = False + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) if needs_attn_backend else None + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ + .create_input_mapper(self.model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + + if hasattr(self, "_builder_cls"): + # multi-step model runner does not have `_builder_cls` + self.builder = self._builder_cls(weakref.proxy(self)) + + def load_model(self) -> None: + self.model = get_model(vllm_config=self.vllm_config) + + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + def get_model(self) -> nn.Module: + return self.model + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> TModelInputForCPU: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. 
+ + """ + self.builder.prepare(finished_requests_ids) + self.builder.set_seq_group_list(seq_group_metadata_list) + + return self.builder.build() # type: ignore + + # sampler property will be used by spec_decode_worker + @property + def sampler(self): + return self.model.sampler + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + + +class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): + _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( + ModelInputForCPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForCPUWithSamplingMetadata: + return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501 + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. 
+ + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine, + is_prompt=is_prompt) + + @torch.no_grad() + def execute_model( + self, + model_input: ModelInputForCPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + previous_hidden_states: Optional[torch.Tensor] = None, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + model_executable = self.model + + multimodal_kwargs = {} + if model_input.multi_modal_kwargs is not None: + multimodal_kwargs = MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs, device=self.device) + execute_model_kwargs = {} + if previous_hidden_states is not None: + execute_model_kwargs.update( + {"previous_hidden_states": previous_hidden_states}) + + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + intermediate_tensors=intermediate_tensors, + **execute_model_kwargs, + **multimodal_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states + return [output] + + def generate_proposals(self, *args, **kwargs): + return self.model.generate_proposals(*args, **kwargs) diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py new file mode 100644 index 0000000..1ceb255 --- /dev/null +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch + +from vllm.forward_context import set_forward_context +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MultiModalKwargs +from vllm.pooling_params import PoolingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, + SequenceGroupMetadata) +from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU, + ModelInputForCPUBuilder) + + +@dataclasses.dataclass(frozen=True) +class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): + """ + Used by the CPUPoolingModelRunner. 
+ """ + pooling_metadata: Optional["PoolingMetadata"] = None + + +class CPUPoolingModelRunner( + CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): + _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( + ModelInputForCPUWithPoolingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForCPUWithPoolingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + model_executable = self.model + cross_enc_kwargs = {} + if model_input.token_type_ids is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids + execute_model_kwargs = { + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + **cross_enc_kwargs, + "intermediate_tensors": + intermediate_tensors, + } + + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_states = model_executable(**execute_model_kwargs) + + # Only perform pooling in the driver worker. + if not self.is_driver_worker: + return [] + + return [ + self.model.pooler(hidden_states=hidden_states, + pooling_metadata=model_input.pooling_metadata) + ] + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForCPUWithPoolingMetadata: + return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithPoolingMetadata: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Prepare PoolingMetadata. 
+ assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) + + return dataclasses.replace(model_input, + virtual_engine=virtual_engine, + pooling_metadata=pooling_metadata) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py new file mode 100644 index 0000000..1436a40 --- /dev/null +++ b/vllm/worker/cpu_worker.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A CPU worker class.""" +import os +from typing import Dict, List, Optional, Set, Tuple, Type + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache +from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase +from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, + WorkerInput) + +logger = init_logger(__name__) + + +class CPUCacheEngine: + """Manages the KV cache for CPU backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: + assert device_config.device_type == "cpu" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for CPU backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]: + self.dtype = torch.float8_e5m2 + else: + raise NotImplementedError(f"Unsupported KV cache type " + f"{cache_config.cache_dtype}.") + + # Get attention backend. 
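+        # Note: the backend chosen here dictates the KV-cache tensor layout
+        # through get_kv_cache_shape(), which _allocate_kv_cache() below uses
+        # to allocate one CPU tensor per layer.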
+ self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + + # Initialize the cache. + self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[torch.Tensor]: + """Allocates KV cache on CPU.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: str, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block if not model_config.use_mla else 0 + total = num_layers * (key_cache_block + value_cache_block) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = torch.tensor([], dtype=dtype).element_size() + return dtype_size * total + + +class CPUWorker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a CPU socket. + + Each worker is associated with a single CPU socket. The worker is + responsible for maintaining the KV cache and executing the model on the + CPU. In case of distributed inference, each worker is assigned a partition + of the model. + """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[CPUModelRunner]] = None, + ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + + self.local_rank = local_rank + self.rank = rank + vllm_config.parallel_config.rank = rank + + self.distributed_init_method = distributed_init_method + + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Setup OpenMP threads affinity. 
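+        # Example (hypothetical core ids): the binding string lists CPU ids per
+        # rank separated by '|', e.g. VLLM_CPU_OMP_THREADS_BIND="0-31|32-63"
+        # pins rank 0 to cores 0-31 and rank 1 to cores 32-63, while "all"
+        # skips explicit pinning (handled below).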
+ omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND + if omp_cpuids == "all": + self.local_omp_cpuid = "all" + else: + self.local_omp_cpuid = omp_cpuids.split("|")[rank] + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) \ + else {"return_hidden_states": True} + ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner + if self.model_config.runner_type == "pooling": + ModelRunnerClass = CPUPoolingModelRunner + elif self.model_config.is_encoder_decoder: + ModelRunnerClass = CPUEncoderDecoderModelRunner + self.model_runner: CPUModelRunnerBase = ModelRunnerClass( + vllm_config=vllm_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + **speculative_args, + ) + if model_runner_cls is not None: + self.model_runner = model_runner_cls(self.model_runner) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CPUCacheEngine] + # Initialize cpu_cache as pooling models don't initialize kv_caches + self.cpu_cache: Optional[List[List[torch.Tensor]]] = None + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + + def init_device(self) -> None: + if self.local_omp_cpuid != "all": + ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + if ret: + logger.info(ret) + + # Note: unique identifier for creating allreduce shared memory + os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split( + ":")[-1] + self.device = torch.device("cpu") + self.init_distributed_environment() + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured CPU + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. + """ + # For CPU device, the block number will be calculated based on the + # cpu_kvcache_space. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // + cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. 
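+        # Worked example (illustrative model dimensions): with 32 layers, 8 KV
+        # heads of head_size 128, block_size 16 and fp16 caches, one block costs
+        # 2 (K and V) * 16 * 8 * 128 * 32 * 2 bytes = 2 MiB, so a 4 GiB
+        # VLLM_CPU_KVCACHE_SPACE budget yields 2048 blocks (32768 cached tokens).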
+ num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid. + """ + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + def _init_cache_engine(self) -> None: + self.cache_engine = [ + CPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.cpu_cache = [ + self.cache_engine[ve].cpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.cpu_cache) + self.model_runner.block_size = self.cache_engine[0].block_size + + assert all( + self.cpu_cache[ve] is not None + for ve in range(self.parallel_config.pipeline_parallel_size)) + + # Populate the cache to warmup the memory + for ve in range(self.parallel_config.pipeline_parallel_size): + for layer_cache in self.cpu_cache[ve]: + layer_cache.fill_(0) + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.cpu_cache + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + def execute_worker( + self, + worker_input: WorkerInput, + ) -> None: + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[worker_input.virtual_engine].copy( + worker_input.blocks_to_copy) + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + assert execute_model_req is not None + virtual_engine: int = execute_model_req.virtual_engine + num_seq_groups: int = 
len(execute_model_req.seq_group_metadata_list) + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block. + """ + return CPUCacheEngine.get_cache_block_size( + self.cache_config.block_size, self.cache_config.cache_dtype, + self.model_config, self.parallel_config) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py new file mode 100644 index 0000000..5f39f2f --- /dev/null +++ b/vllm/worker/enc_dec_model_runner.py @@ -0,0 +1,516 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import itertools +from typing import Any, Dict, List, Optional, Tuple, Type, cast + +import torch +import torch.distributed + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.attention.selector import (get_env_variable_attn_backend, + get_global_forced_attn_backend) +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + MultiModalRegistry) +from vllm.platforms import _Backend +from vllm.sampling_params import SamplingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, + SequenceGroupMetadata) +from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict) +from vllm.worker.utils import assert_enc_dec_mr_supported_scenario + +logger = init_logger(__name__) + + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): + """ + Used by the EncoderDecoderModelRunner. 
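+    On top of the decoder-oriented sampling input it carries the encoder token
+    ids and positions prepared by the runner.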
+ """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInput": + return cast( + EncoderDecoderModelInput, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): + _model_input_cls: Type[EncoderDecoderModelInput] = ( + EncoderDecoderModelInput) + _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + ''' + EncoderDecoderModelRunner constructor. + + `lora_config` and `prompt_adapter_config` are + unused (since these features are not yet supported for encoder/decoder + models) but these arguments are present here for compatibility with + the base-class constructor. + ''' + self._maybe_force_supported_attention_backend() + + super().__init__( + vllm_config=vllm_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + ) + + # Crash for unsupported encoder/scenarios + assert_enc_dec_mr_supported_scenario(self) + + def _maybe_force_supported_attention_backend(self): + ''' + Force vLLM to use the XFormers attention backend, + which is currently the only supported option. + ''' + + def raise_backend_err(): + # The user has specified an attention backend override + # which is invalid for encoder/decoder models + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND) + + maybe_env_var_forced_backend = get_env_variable_attn_backend() + maybe_global_forced_backend = get_global_forced_attn_backend() + is_forced_by_global = maybe_global_forced_backend is not None + is_forced_by_env_var = maybe_env_var_forced_backend is not None + if is_forced_by_global: # noqa: SIM102 + # Backend override enforced by global variable takes + # precedence over vLLM backend environment variable. 
+ if maybe_global_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: + raise_backend_err() + elif is_forced_by_env_var: # noqa: SIM102 + # Backend override enforced by vLLM backend + # environment variable + if maybe_env_var_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: + raise_backend_err() + + def _list_to_int32_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.int32, device=self.device) + + def _list_to_long_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.long, device=self.device) + + def _empty_int32_tensor(self) -> torch.Tensor: + return self._list_to_int32_tensor([]) + + def _empty_long_tensor(self) -> torch.Tensor: + return self._list_to_long_tensor([]) + + @torch.inference_mode() + def execute_model( + self, + model_input: EncoderDecoderModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[PoolerOutput]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in " + "EncoderDecoderModelRunner") + + if (model_input.attn_metadata is not None + and model_input.attn_metadata.prefill_metadata is None + and model_input.attn_metadata.decode_metadata.use_cuda_graph): + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[ + model_input.virtual_engine][graph_batch_size] + else: + model_executable = self.model + + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output] + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: + return EncoderDecoderModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInput: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + Since chunked prefill is not supported for encoder/decoder models, + `input_tokens` is assumed to be either entirely prefill tokens or + entirely decode tokens. 
+ + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input)) + # Inject attn_metadata encoder/cross-attention fields & + # encoder input tokens/positions into model_input. + # Frozen dataclass fields cannot be modified, so use + # dataclasses.replace to construct a new model input + # instance. + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory, + generators=generators) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + if max_mm_tokens > 0: + logger.info("Starting profile run for multi-modal models.") + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + decoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=False) + encoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) + + # Having more tokens is over-conservative but otherwise fine + assert len( + decoder_dummy_data.seq_data.prompt_token_ids + ) >= seq_len, ( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" + ) + + assert decoder_dummy_data.multi_modal_data is None or \ + encoder_dummy_data.multi_modal_data is None, ( + "Multi-modal data can't be provided in both encoder and decoder" + ) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: decoder_dummy_data.seq_data}, + sampling_params=sampling_params, + block_tables=None, + encoder_seq_data=encoder_dummy_data.seq_data, + cross_block_table=None, + multi_modal_data=decoder_dummy_data.multi_modal_data + or encoder_dummy_data.multi_modal_data, + multi_modal_placeholders=decoder_dummy_data. 
+ multi_modal_placeholders + or encoder_dummy_data.multi_modal_placeholders) + seqs.append(seq) + + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + self.execute_model(model_input, None, intermediate_tensors) + torch.cuda.synchronize() + return + + def _prepare_encoder_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInput, + ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Helper method to prepare the encoder- and cross-attn-related + model inputs based on a given sequence group. These additional inputs + are used to augment an already-computed `EncoderDecoderModelInput` + data structure which already has decoder-related model inputs + populated. + + Sets the following attn_metadata fields: + * `num_encoder_tokens` + * `encoder_seq_lens` + * `encoder_seq_lens_tensor` + * `max_encoder_seq_len` + * `cross_slot_mapping` + * `cross_block_tables` + + Constructs a new model inputs data structure, based on + (1) the existing fields in the `model_inputs` argument, + and (2) the following additional fields which are + computed (or in the case of `attn_metadata`, updated) + by this function: + * attn_metadata + * encoder_input_tokens + * encoder_input_positions + + Arguments: + + * seq_group_metadata_list: list of sequence groups for which to + compute inputs + * model_inputs: model inputs data structure with decoder-oriented + fields already computed. + + Return: + + * Updated model inputs data structure + """ + + if len(seq_group_metadata_list) == 0: + return (model_input.attn_metadata, None, None) + + # Since we are not supporting chunked prefill either the entire + # batch is prefill or it is decode + is_prompt = seq_group_metadata_list[0].is_prompt + + # Build encoder inputs + encoder_seq_lens: List[int] = [] + if is_prompt: + # Prefill phase. + cross_block_tables = self._empty_int32_tensor().view( + len(seq_group_metadata_list), -1) + + # Extract input tokens/positions, cross-attention slot-mapping, + # & seq len from each sequence group metadata + ( + encoder_input_tokens, + encoder_input_positions, + cross_slot_mapping, + ) = ( + [], + [], + [], + ) + for seq_group_metadata in seq_group_metadata_list: + # Build seq lens + seq_len = seq_group_metadata.encoder_seq_data.get_len() + token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() + encoder_seq_lens.append(seq_len) + + # Build slot mapping + is_profile_run = (seq_group_metadata.block_tables is None) + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. 
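+                # Note: PAD_SLOT_ID only fills the mapping with dummy slots
+                # while profiling; in the non-profile branch below each encoder
+                # position i maps to cross_block_table[i // block_size] *
+                # block_size + i % block_size, mirroring the decoder-side
+                # slot mapping.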
+ cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) + else: + for i in range(0, seq_len): + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + cross_slot_mapping.append(slot) + + # Build encoder input tokens + encoder_input_tokens.extend(token_ids) + encoder_input_positions.extend(list(range(0, seq_len))) + + # Convert tokens/positions & cross-attention + # slot-mapping to encoder input tensors + encoder_input_tokens_tensor = self._list_to_long_tensor( + encoder_input_tokens) + encoder_input_positions_tensor = self._list_to_long_tensor( + encoder_input_positions) + cross_slot_mapping_tensor = self._list_to_long_tensor( + cross_slot_mapping) + + else: + # Decode phase. + encoder_input_tokens_tensor = self._empty_long_tensor() + encoder_input_positions_tensor = self._empty_long_tensor() + cross_slot_mapping_tensor = self._empty_long_tensor() + # Extract cross-attention block tables & + # seq len from each sequence group metadata. + # Cross-attention block tables are empty + # during vLLM memory profiling. + cross_block_tables = [] + for seq_group_metadata in seq_group_metadata_list: + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) + + if (model_input.attn_metadata is not None + and model_input.attn_metadata.use_cuda_graph): + # We will be using CUDA graph replay for this decode. + max_len_of_block_table = self.get_max_block_per_batch() + batch_size = len(encoder_seq_lens) + graph_batch_size = self.vllm_config.pad_for_cudagraph( + batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + # extend the cross_block_tables and encoder_seq_lens to match + # the graph_batch_size. 
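+                # For example, a decode batch of 3 padded to a captured graph
+                # batch of 4 appends one dummy entry (an empty block table and
+                # an encoder seq len of 1) so the replayed tensors keep the
+                # shapes the graph was recorded with.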
+ cross_block_tables.extend([[] + for _ in range(cuda_graph_pad_size) + ]) + encoder_seq_lens.extend( + itertools.repeat(1, cuda_graph_pad_size)) + + else: + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + + cross_block_tables = make_tensor_with_pad( + cross_block_tables, + max_len=max_len_of_block_table, + pad=0, + dtype=torch.int32, + device=self.device, + ) + + # Compute encoder sequence lengths & encoder + # sequence starting offset tensors + max_encoder_seq_len = max(encoder_seq_lens, default=0) + encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + + # Update attention metadata with encoder-oriented attributes + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + ( + attn_metadata.num_encoder_tokens, + attn_metadata.encoder_seq_lens, + attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.cross_slot_mapping, + attn_metadata.cross_block_tables, + ) = ( + sum(encoder_seq_lens), + encoder_seq_lens, + encoder_seq_lens_tensor, + max_encoder_seq_len, + encoder_seq_start_loc, + cross_slot_mapping_tensor, + cross_block_tables, + ) + + return (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py new file mode 100644 index 0000000..6b1593e --- /dev/null +++ b/vllm/worker/hpu_model_runner.py @@ -0,0 +1,2113 @@ +# SPDX-License-Identifier: Apache-2.0 + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company +############################################################################### + +import collections +import contextlib +import dataclasses +import functools +import gc +import itertools +import math +import operator +import os +import time +from array import array +from dataclasses import dataclass, field +from enum import IntEnum +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, + Optional, Set, Tuple, Type, TypeVar, Union) + +import habana_frameworks.torch as htorch +import habana_frameworks.torch.internal.bridge_config as bc +import torch +import torch.nn as nn +from vllm_hpu_extension.ops import LoraMask as LoraMask +from vllm_hpu_extension.ops import batch2block, block2batch +from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, + HabanaMemoryProfiler, format_bytes) + +import vllm.envs as envs +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import DeviceConfig, VllmConfig +from vllm.distributed.parallel_state import get_world_group +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs) +from vllm.sampling_params import SamplingParams +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) +from vllm.utils import (bind_kv_cache, is_pin_memory_available, + make_tensor_with_pad) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +_TYPE_CACHE = {} +# These values are assumed to be zero in several places. +# Use caution when updating them! 
+_PAD_SLOT_ID = 0 +_PAD_BLOCK_ID = 0 + +LORA_WARMUP_RANK = 8 + + +class Singleton(type): + _instances: Dict[type, object] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class HPUBucketingGlobalState(metaclass=Singleton): + prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_buckets: List[Tuple[int, int]] = field(init=False) + decode_buckets: List[Tuple[int, int]] = field(init=False) + + +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, + ' '.join(fields)) + return _TYPE_CACHE[typename](**values) + + +def read_bucket_settings(phase: str, dim: str, **defaults): + """Read bucketing configuration from env variables. + + phase is either 'prompt' or 'decode' + dim is either 'bs', 'seq' or 'block' + param is either 'min', 'step' or 'max' + example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 + """ + params = ['min', 'step', 'max'] + env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] + default_values = [defaults[p] for p in params] + values = [ + int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) + ] + for e, v, d in zip(env_vars, values, default_values): + logger.info('%s=%s (default:%s)', e, v, d) + return values + + +def warmup_range(config: Tuple[int, int, int]): + """Generate a warmup range. + + Start from bmin and multiply by 2 until you reach bstep. + Then, increase the values in the range by the value of bstep until you + reach bmax. + + Example: + bmin = 2, bstep = 32, bmax = 64 + => ramp_up = (2, 4, 8, 16) + => stable = (32, 64) + => return ramp_up + stable => (2, 4, 8, 16, 32, 64) + """ + bmin, bstep, bmax = config + assert bmin <= bmax, ("Min. batch size cannot be greater than max. " + "batch size. 
If you want to skip warmup, " + "set VLLM_SKIP_WARMUP=true") + base = itertools.repeat(2) + ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) + ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \ + ramp_up_acc) + stable = range(bstep, bmax + 1, bstep) + buckets = list(ramp_up_tw) + list(stable) + return list(filter(lambda bucket: bucket >= bmin, buckets)) + + +def generate_prompt_buckets(bs_bucket_config, + seq_bucket_config, + max_num_batched_tokens=None): + buckets = list( + itertools.product(warmup_range(bs_bucket_config), + warmup_range(seq_bucket_config))) + if len(buckets) == 0: + msg = ("No buckets could be captured with following config " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config}") + raise ValueError(msg) + + filtered_buckets = buckets + if max_num_batched_tokens is not None: + # Remove buckets exceeding batch token budget + filtered_buckets = list( + filter( + lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, + buckets)) + + if len(filtered_buckets) == 0: + # we can handle this if we ignore max_num_batched_tokens + min_bucket_bs, min_bucket_seq = min(buckets, + key=lambda b: (b[0] * b[1])) + min_reqd_budget = min_bucket_bs * min_bucket_seq + msg = ( + "The current bucketing configuration " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config} cannot be used with specified " + f"max_num_batched_tokens ({max_num_batched_tokens}), as the " + f"smallest bucket ({min_reqd_budget}) would exceed token " + "budget. Please increase max_num_batched_tokens or decrease " + "bucket minimum Ignoring max_num_batched_tokens at risk of " + "out-of-memory errors.") + logger.error(msg) + return list( + sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), [] + + captured_buckets = list( + sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + omitted_buckets = list( + sorted([x for x in buckets if x not in filtered_buckets])) + return captured_buckets, omitted_buckets + + +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, + max_blocks): + buckets = [] + bs_buckets = warmup_range(bs_bucket_config) + block_buckets = warmup_range(blocks_bucket_config) + bmin, bstep, bmax = blocks_bucket_config + last_bucket = round_up(max_blocks, bstep) + for bs in bs_buckets: + for blocks in block_buckets: + if blocks < bs: + continue + if blocks > last_bucket: + break + buckets.append((bs, blocks)) + return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + + +def next_pow2(value: int, base: int): + res = base + while value > 1: + value = (value + 1) // 2 + res *= 2 + return res + + +def round_up(value: int, k: int): + return (value + k - 1) // k * k + + +def find_bucket(value: int, config: Tuple[int, int, int]): + bmin, bstep, _ = config + next_step = round_up(value, bstep) + next_pow = next_pow2(value, bmin) + return max(bmin, min(next_step, next_pow)) + + +def align_workers(value, op): + group = get_world_group().cpu_group + world_size = torch.distributed.get_world_size() + if world_size <= 1: + return value + value_t = torch.tensor(value, device='cpu') + torch.distributed.all_reduce(value_t, op=op, group=group) + return value_t.item() + + +def setup_profiler(): + schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) + DEVICE = 'hpu' + activities = [torch.profiler.ProfilerActivity.CPU] + activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == + 'hpu' else []) + #from 
habana_frameworks.torch.activity_profiler import DebugActivity + #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] + + profiler = torch.profiler.profile( + schedule=schedule, + activities=activities, + #debug_activities=debug_activities, + on_trace_ready=torch.profiler.tensorboard_trace_handler('.', + use_gzip=True), + record_shapes=False, + with_stack=True) + return profiler + + +def pad_list(input, k, v): + input_len = len(input) + target_len = round_up(input_len, k) + padding = target_len - input_len + return input + [v] * padding + + +def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] + + +def flatten(in_list): + return list(itertools.chain(*in_list)) + + +def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + if is_prompt: + indices = indices.unflatten(0, (-1, block_size))[:, 0] + offsets = None + else: + offsets = torch.fmod(slot_mapping, block_size) + return indices, offsets + + +def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"): + if module.__class__.__name__.endswith(suffix): + + def forward_hook(module, args, output): + htorch.core.mark_step() + return output + + module.register_forward_hook(forward_hook) + + for child_name, child_module in module.named_children(): + modify_decoder_layer(child_module) + + +class HpuModelAdapter: + + def __init__(self, model, vllm_config): + self.model = model + self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '0').lower() in ['1', 'true'] + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.dtype = vllm_config.model_config.dtype + enforce_eager = vllm_config.model_config.enforce_eager + + if not htorch.utils.internal.is_lazy() and not enforce_eager: + if os.getenv('VLLM_REGIONAL_COMPILATION', + 'true').lower() == 'true': + self.regional_compilation_layers_list = [ + RMSNorm, VocabParallelEmbedding + ] + self._regional_compilation(self.model) + else: + self.model = torch.compile(self.model, + backend='hpu_backend', + dynamic=False) + + def _regional_compilation(self, + module, + parent_module=None, + module_name=None): + if isinstance(module, torch.nn.ModuleList): + for children_name, children_module in module.named_children(): + self._compile_region(module, children_name, children_module) + elif any( + isinstance(module, layer) + for layer in self.regional_compilation_layers_list): + self._compile_region(parent_module, module_name, module) + else: + for children_name, children_module in module.named_children(): + self._regional_compilation(children_module, module, + children_name) + + def _compile_region(self, model, name, module): + module = torch.compile(module, backend='hpu_backend', dynamic=False) + setattr(model, name, module) + + def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, + dtype): + prefill_metadata = attn_metadata + if prefill_metadata is None or self.prefill_use_fusedsdpa: + return attn_metadata + + seq_lens_t = prefill_metadata.seq_lens_tensor + len_mask = (torch.arange(0, seq_len, device=device, + dtype=torch.int32).view(1, seq_len).ge( + seq_lens_t.unsqueeze(-1)).view( + batch_size, 1, 1, seq_len)) + causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), + device=device, + dtype=torch.bool), + diagonal=1) + mask = causal_mask.logical_or(len_mask) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, 
-math.inf)) + attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) + return attn_metadata + + def _set_block_mapping(self, metadata, batch_size, device, dtype): + mask = torch.arange(0, + self.block_size, + device=device, + dtype=torch.int32).unsqueeze(0) + mask = mask >= metadata.block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + if os.environ.get('VLLM_USE_FAKE_HPU', + '0') == '0' and htorch.utils.internal.is_lazy(): + block_mapping = torch.nn.functional.one_hot(metadata.block_groups, + num_classes=batch_size) + else: + # Unfortunately one_hot on CPU/torch.compile mode/eager mode + # doesn't handle out of bounds classes so we need to convert + # all negative values to 0 (block_mapping) or bs (block_groups) + block_groups = metadata.block_groups.to(torch.long) + block_mapping = torch.nn.functional.relu(block_groups) + block_mapping = torch.nn.functional.one_hot(block_mapping, + num_classes=batch_size) + oob_values = block_groups.lt(0) + block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) + block_groups.masked_fill_(oob_values, batch_size) + metadata = metadata._replace(block_groups=block_groups) + block_mapping = block_mapping.to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, + attn_bias=attn_bias) + return metadata + + def _set_block_scales(self, metadata, device): + block_mapping = metadata.block_mapping + ones = torch.ones((block_mapping.size(0), ), + device=device, + dtype=block_mapping.dtype) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + metadata = metadata._replace(block_scales=block_scales) + return metadata + + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, + dtype): + if attn_metadata.is_prompt: + meta = attn_metadata + attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, + device, dtype) + else: + meta = attn_metadata + attn_metadata = self._set_block_mapping(meta, batch_size, device, + dtype) + attn_metadata = self._set_block_scales(attn_metadata, device) + return attn_metadata + + def forward(self, *args, **kwargs): + kwargs = kwargs.copy() + selected_token_indices = kwargs.pop('selected_token_indices') + if 'warmup_mode' in kwargs: + kwargs.pop('warmup_mode') + virtual_engine = 0 + if 'virtual_engine' in kwargs: + virtual_engine = kwargs.pop('virtual_engine') + input_ids = kwargs['input_ids'] + attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'), + input_ids.size(0), + input_ids.size(1), + input_ids.device, self.dtype) + LoraMask.setLoraMask(kwargs.pop('lora_mask')) + with set_forward_context(attn_metadata, self.vllm_config, + virtual_engine): + hidden_states = self.model(*args, **kwargs) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = hidden_states.index_select(0, + selected_token_indices) + return hidden_states + + def compute_logits(self, *args, **kwargs): + return self.model.compute_logits(*args, **kwargs) + + def sample(self, *args, **kwargs): + return self.model.sample(*args, **kwargs) + + +class PreparePromptMetadata(NamedTuple): + input_tokens: torch.Tensor + input_positions: List[List[int]] + attn_metadata: Optional[AttentionMetadata] + seq_lens: List[int] + query_lens: List[int] + lora_index_mapping: List[List[int]] + lora_prompt_mapping: List[List[int]] + lora_requests: Set[LoRARequest] + multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]] + slot_mapping: List[List[int]] + lora_ids: List[int] + 
+ @classmethod + def empty(cls): + return PreparePromptMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_kwargs=None, + slot_mapping=[], + lora_ids=[]) + + +class PrepareDecodeMetadata(NamedTuple): + input_tokens: torch.Tensor + input_positions: List[List[int]] + attn_metadata: Optional[AttentionMetadata] + lora_index_mapping: List[List[int]] + lora_prompt_mapping: List[List[int]] + lora_requests: Set[LoRARequest] + slot_mapping: List[List[int]] + lora_ids: List[int] + + @classmethod + def empty(cls): + return PrepareDecodeMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + lora_ids=[]) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + +TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU") + + +@dataclasses.dataclass(frozen=True) +class ModelInputForHPU(ModelRunnerInputBase): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + real_batch_size: Optional[int] = None + batch_size_padded: Optional[int] = None + virtual_engine: int = 0 + lora_ids: Optional[List[int]] = None + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "real_batch_size": self.real_batch_size, + "batch_size_padded": self.batch_size_padded, + "virtual_engine": self.virtual_engine, + "lora_ids": self.lora_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForHPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForHPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclasses.dataclass(frozen=True) +class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + # Used for speculative decoding. We do not broadcast it because it is only + # used by the driver worker. 
+ is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_ids": self.lora_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForHPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + # FIXME(kzawora): this fails for whatever reason - why? + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): + """ + Helper class for shared methods between GPU model runners. + """ + _model_input_cls: Type[TModelInputForHPU] + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool = False, + return_hidden_states: bool = False, + ): + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states + + self.sliding_window = (self.model_config.get_sliding_window() + if self.model_config is not None else None) + self.device_config = (self.device_config if self.device_config + is not None else DeviceConfig()) + self.device = self.device_config.device + self.enforce_eager = self.model_config.enforce_eager + self.max_num_seqs = self.scheduler_config.max_num_seqs + # NOTE(kzawora): Change that to scheduler_config.max_num_prefill_seqs + # once padding-aware scheduling gets merged + self.max_num_prefill_seqs = 64 + self.max_model_len = self.scheduler_config.max_model_len + self.max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.block_size = self.cache_config.block_size + + self.pin_memory = is_pin_memory_available() + self.kv_cache_dtype = self.cache_config.cache_dtype + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Lazy initialization + self.lora_manager: LRUCacheWorkerLoRAManager = None + self.model: torch.nn.Module = None + self.inc_initialized_successfully = False + + # Profiler stats + self.profiler = HabanaHighLevelProfiler() + self.profiler_counter_helper = HabanaProfilerCounterHelper() + self.seen_configs: set = set() + self._mem_margin: Optional[int] = None + self.bucketing_global_state = HPUBucketingGlobalState() + self._setup_buckets() + self._set_gc_threshold() + self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH + + def _set_gc_threshold(self) -> None: + # Read https://docs.python.org/3/library/gc.html#gc.set_threshold + # for comprehensive description of gc generations. + # We can either use VLLM_GC_THR_GEN[0-2] (this has higher priority) + # to set particular generation threshold or use simpler + # VLLM_GC_THR_MULTIPLIER to multiply default values. 
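+        # For example, with CPython's default thresholds of (700, 10, 10) and
+        # no VLLM_GC_THR_GEN* overrides, the default VLLM_GC_THR_MULTIPLIER=2
+        # below results in gc.set_threshold(1400, 20, 20).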
+ default_gc_thrs = list(gc.get_threshold()) + requested_gc_thrs = [0] * len(default_gc_thrs) + for i in range(len(default_gc_thrs)): + requested_gc_thrs[i] = int( + os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i])) + if requested_gc_thrs == default_gc_thrs: + gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', + 2)) + requested_gc_thrs = [ + t * gc_thr_multiplier for t in default_gc_thrs + ] + gc.set_threshold(*requested_gc_thrs) + + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) + + self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', + 'false').lower() == 'true' + + def load_model(self) -> None: + import habana_frameworks.torch.core as htcore + if self.model_config.quantization == 'inc' or \ + self.model_config.quantization == 'fp8': + htcore.hpu_set_env() + with HabanaMemoryProfiler() as m: + with HabanaMemoryProfiler() as m_getmodel: + self.model = get_model(vllm_config=self.vllm_config) + msg = ("Pre-loading model weights on " + f"{next(self.model.parameters()).device} " + f"took {m_getmodel.get_summary_string()}") + logger.info(msg) + + if self.lora_config: + assert hasattr(self.model, "embedding_modules" + ), "Model does not have embedding_modules" + assert hasattr( + self.model, "embedding_padding_modules" + ), "Model does not have embedding_padding_modules" + assert not self.lora_config.bias_enabled, \ + "Bias support in LoRA is not enabled in HPU yet." + assert not self.lora_config.fully_sharded_loras, \ + "Fully sharded LoRAs is not enabled in HPU yet." + # It's necessary to distinguish between the + # max_position_embeddings of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = ( + self.model.config.max_position_embeddings) + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.model_config.quantization == 'inc': + logger.info("Preparing model with INC..") + with HabanaMemoryProfiler() as m_inc: + from neural_compressor.torch.quantization import ( + FP8Config, convert, prepare) + config = FP8Config.from_json_file( + os.getenv("QUANT_CONFIG", "")) + if config.measure: + self.model = prepare(self.model, config) + elif config.quantize: + self.model = convert(self.model, config) + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) + self.inc_initialized_successfully = True + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) + else: + self.model = self.model.to("hpu") + htcore.mark_step() + modify_decoder_layer(self.model) + torch.hpu.synchronize() + + with HabanaMemoryProfiler() as m_wrap: + self.model = _maybe_wrap_in_hpu_graph( + self.model, vllm_config=self.vllm_config) + msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" + logger.info(msg) + + self.model_memory_usage = m.consumed_device_memory + msg = f"Loading model weights took in total {m.get_summary_string()}" + logger.info(msg) + + def get_model(self) -> nn.Module: + return self.model + + def _use_graphs(self, batch_size, seq_len, is_prompt): + if self.enforce_eager: + return False + if 
self.skip_warmup: + return True + return (batch_size, seq_len, is_prompt) in self.graphed_buckets + + def _is_valid_bucket(self, bucket): + return bucket[0] * bucket[1] <= self.max_num_batched_tokens + + def _setup_buckets(self) -> None: + align_bs = lambda x: min(self.max_num_seqs, x) + #FIXME: The default values should be max_model_len + max_prompt_seq = 1024 + max_decode_seq = 2048 + self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings( + 'prompt', + 'bs', + min=1, + step=align_bs(32), + max=self.max_num_prefill_seqs) + self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings( + 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) + self.bucketing_global_state.prompt_seq_bucket_cfg = \ + read_bucket_settings( + 'prompt', + 'seq', + min=self.block_size, + step=self.block_size, + max=max_prompt_seq) + self.bucketing_global_state.decode_block_bucket_cfg = \ + read_bucket_settings( + 'decode', + 'block', + min=self.block_size, + step=self.block_size, + max=max(self.block_size, + self.max_num_seqs * max_decode_seq // self.block_size)) + self.graphed_buckets: Set[Any] = set() + + msg = ("Prompt bucket config (min, step, max_warmup) " + f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, " + f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}") + logger.info(msg) + + msg = ("Decode bucket config (min, step, max_warmup) " + f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, " + f"block:{self.bucketing_global_state.decode_block_bucket_cfg}") + logger.info(msg) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> PreparePromptMetadata: + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + lora_index_mapping: List[List[int]] = [] + lora_prompt_mapping: List[List[int]] = [] + lora_requests: Set[LoRARequest] = set() + + seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] + + if len(seq_group_metadata_list) == 0: + return PreparePromptMetadata.empty() + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + token_chunk_size = seq_group_metadata.token_chunk_size + seq_data = seq_group_metadata.seq_data[seq_id] + context_len = seq_data.get_num_computed_tokens() + # We should use get_len here because in case of preemption + # it contains output tokens. + seq_len = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) + + # NOTE: This only works for oooooooxxx style attention. 
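+            # ('o' denotes tokens whose KV-cache entries are already computed,
+            # e.g. a cached prefix; 'x' denotes the new tokens that still need
+            # to be prefilled, so the computed span must be a leading prefix.)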
+ if computed_block_nums is not None and len( + computed_block_nums) > 0 and self.sliding_window is None: + # Prefix is not supported with sliding_window + context_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[context_len:] + prefix_block_tables.append(computed_block_nums) + elif self.scheduler_config.chunked_prefill_enabled: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) + else: + prefix_block_tables.append([]) + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. + assert context_len == 0 + + # actual prompt lens + context_lens.append(context_len) + query_lens.append(seq_len - context_len) + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append(list(range(context_len, seq_len))) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_kwargs_list.append(mm_kwargs) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
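+            # For example, assuming a hypothetical block_table of [0, 1, 0]
+            # for that case: token i=5 lands in block_table[5 // 4] = 1 at
+            # offset 5 % 4 = 1, so its slot is 1 * 4 + 1 = 5.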
+ start_idx = 0 + if self.sliding_window is not None: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + start_idx = max(0, seq_len - self.sliding_window) + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + max_query_len = max(query_lens) + sum_query_len = sum(query_lens) + real_num_seqs = len(query_lens) + assert max_query_len > 0 + + max_prompt_len = max( + find_bucket(max(seq_lens), + self.bucketing_global_state.prompt_seq_bucket_cfg), + self.block_size) + + lora_ids: List[int] = [] + for seq_group_metadata, context_len in zip(seq_group_metadata_list, + context_lens): + lora_id = seq_group_metadata.lora_int_id + lora_ids.append(lora_id) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (max_prompt_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (max_prompt_len - context_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + + input_tokens = make_tensor_with_pad(input_tokens, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + + input_positions = make_tensor_with_pad(input_positions, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + + slot_mapping = make_tensor_with_pad(slot_mapping, + max_len=max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.long, + device=self.device) + + block_indices, block_offsets = precompute_indices_and_offsets( + self.block_size, slot_mapping, True) + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + block_list=None, + block_mapping=None, + block_usage=None, + block_indices=block_indices, + block_offsets=block_offsets, + block_scales=None, + block_groups=None, + attn_bias=None, + seq_lens_tensor=seq_lens_tensor, + num_prefills=real_num_seqs, + num_prefill_tokens=sum_query_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps= + None, # FIXME(kzawora): mutli-modality will not work here + enable_kv_scales_calculation=False, + ) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + + return PreparePromptMetadata(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + slot_mapping=slot_mapping, + lora_ids=lora_ids) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> PrepareDecodeMetadata: + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + lora_index_mapping: List[List[int]] = [] + lora_prompt_mapping: List[List[int]] = [] + lora_requests: Set[LoRARequest] = set() + + if len(seq_group_metadata_list) == 0: + return PrepareDecodeMetadata.empty() + lora_ids: List[int] = [] + + dummy_slots = itertools.cycle( + range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + 
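+            # Decode handles exactly one new token per sequence per step.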
assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + lora_id = seq_group_metadata.lora_int_id + lora_ids.append(lora_id) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + if len(block_table) == 0: + block_number = _PAD_BLOCK_ID + else: + block_number = block_table[position // self.block_size] + if block_number == _PAD_BLOCK_ID: + slot = next(dummy_slots) + else: + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + lora_index_mapping.append(lora_id) + lora_prompt_mapping.append(lora_id) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + + num_decode_tokens = sum(seq_lens) + + last_block_usage = [ + slot[0] % self.block_size + 1 for slot in slot_mapping + ] + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [[self.block_size] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, last_block_usage) + if bt] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + + padding_fn = None + if self.use_contiguous_pa: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + block_bucket_size = find_bucket( + block_bucket_size, + self.bucketing_global_state.decode_block_bucket_cfg) + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + padding_fn = lambda tensor, pad_value: gather_list( + tensor, indices, pad_value) + else: + block_bucket_size = find_bucket( + len(block_list), + self.bucketing_global_state.decode_block_bucket_cfg) + padding_fn = lambda tensor, pad_value: pad_list( + tensor, block_bucket_size, pad_value) + + block_list = padding_fn(block_list, _PAD_BLOCK_ID) + block_groups = padding_fn(block_groups, -1) + block_usage = padding_fn(block_usage, 1) + + block_list = torch.tensor(block_list, + dtype=torch.int, + device=self.device) + block_groups = torch.tensor(block_groups, + dtype=torch.int, + device=self.device) + block_usage = torch.tensor(block_usage, + dtype=self.model_config.dtype, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + block_indices, block_offsets = precompute_indices_and_offsets( + self.block_size, slot_mapping, False) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + block_list=block_list, + block_mapping=None, + block_usage=block_usage, + block_indices=block_indices, + block_offsets=block_offsets, + block_scales=None, + block_groups=block_groups, + attn_bias=None, + seq_lens_tensor=None, + num_prefills=0, + 
num_prefill_tokens=0, + num_decode_tokens=num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + ) + return PrepareDecodeMetadata(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + slot_mapping=slot_mapping, + lora_ids=lora_ids) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[TModelInputForHPU, SamplingMetadata]: + if len(seq_group_metadata_list) == 0: + return self._model_input_cls(), None + + input_tokens = None + input_positions = None + lora_mapping = None + lora_requests = None + multi_modal_kwargs = None + batch_type = None + seq_lens = None + query_lens = None + real_batch_size = None + batch_size_padded = None + + self.event_start = self.profiler.get_timestamp_us() + is_prompt = seq_group_metadata_list[0].is_prompt + base_event_name = 'prompt' if is_prompt else 'decode' + self.profiler.start('internal', base_event_name) + + real_batch_size = len(seq_group_metadata_list) + bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \ + if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg + batch_size_padded = find_bucket(real_batch_size, bucket_cfg) + batch_size_padding = batch_size_padded - real_batch_size + seq_group_metadata_list = seq_group_metadata_list.copy() + if batch_size_padding > 0: + dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( + 0, 0, is_prompt) + seq_group_metadata_list.extend(dummy_seq_group_metadata + for _ in range(batch_size_padding)) + + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + + # Prepare input tensors. + ( + input_tokens, + input_positions, + prefill_attn_metadata, + seq_lens, + query_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_kwargs, + slot_mapping, + lora_ids, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + decode_lora_ids, + ) = self._prepare_decode(decode_reqs) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + seq_lens, query_lens, + self.device, + self.pin_memory) + + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(seq_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # NOTE(kzawora): Here we diverge from GPU code - we don't + # support mixed batches, so we either use decode or prefill + # inputs, without coalescing. + assert (num_prefills == 0 and num_decode_tokens > 0) or ( + num_prefills > 0 + and num_decode_tokens == 0), "HPU does not support mixed batches!" 
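+        # From here on exactly one of the two prepared inputs is forwarded:
+        # the decode inputs when any decode tokens exist, otherwise the
+        # prefill inputs.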
+ if num_decode_tokens > 0: + input_tokens = decode_input_tokens + input_positions = decode_input_positions + slot_mapping = decode_slot_mapping + lora_index_mapping = decode_lora_index_mapping + lora_prompt_mapping = decode_lora_prompt_mapping + lora_requests = decode_lora_requests + lora_ids = decode_lora_ids + + # FIXME: We need to adjust selected_token_indices to accommodate + # for padding + max_len = input_tokens.size(1) + paddings = [max_len - s for s in seq_lens] + paddings = [0] + paddings[:-1] + paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) + paddings = torch.tensor( + paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, + dtype=sampling_metadata.selected_token_indices.dtype, + device=sampling_metadata.selected_token_indices.device) + sampling_metadata.selected_token_indices.add_(paddings) + + if self.lora_config: + lora_mapping = LoRAMapping( + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=(num_prefills > 0))) + else: + lora_mapping = None + + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + raise NotImplementedError("Mixed batch is not supported on HPU") + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": sampling_metadata.selected_token_indices, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_kwargs": multi_modal_kwargs, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, + "seq_lens": seq_lens, + "query_lens": query_lens + } + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + assert decode_attn_metadata is not None + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + + attn_metadata = prefill_attn_metadata if \ + prefill_attn_metadata is not None else decode_attn_metadata + + return self._model_input_cls(input_tokens=input_tokens, + seq_lens=seq_lens, + query_lens=query_lens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_requests=lora_requests, + lora_mapping=lora_mapping, + multi_modal_kwargs=multi_modal_kwargs, + real_batch_size=real_batch_size, + batch_size_padded=batch_size_padded, + lora_ids=lora_ids), \ + sampling_metadata + + def _seq_len(self, attn_metadata): + if attn_metadata.num_prefills != 0: + return attn_metadata.slot_mapping.size(1) + else: + return attn_metadata.block_list.numel() + + def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. + + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. 
You can also hash
+        # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+        # If you use primitive types here - they will get hashed based
+        # on their value. You *will* get lots of excessive graph captures
+        # (and an OOM eventually) if you decide to put something like
+        # seq_len int here.
+        # If you absolutely need a scalar, put it in a tensor. Tensors
+        # get hashed using their metadata, not their values:
+        # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+        # input_hash(123) != input_hash(321)
+        # input_hash("abc") != input_hash("cba")
+        attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
+            'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
+            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
+            'block_offsets', 'block_scales', 'block_groups'
+        ])
+        return attention_metadata
+
+    def create_dummy_seq_group_metadata(self,
+                                        group_id,
+                                        seq_len,
+                                        is_prompt,
+                                        lora_request=None):
+        sampling_params = SamplingParams(temperature=0)
+        num_blocks = math.ceil(seq_len / self.block_size)
+        seq_len = max(seq_len, 1)
+        if is_prompt:
+            input_len = seq_len
+            output_len = 0
+            block_tables = None
+        else:
+            input_len = seq_len - 1
+            output_len = 1
+            block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
+        prompt_token_ids = [0] * input_len
+        output_token_ids = [1] * output_len
+        prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
+        seq_data = SequenceData(prompt_token_ids_array)
+        seq_data.output_token_ids = output_token_ids
+        return SequenceGroupMetadata(request_id=str(group_id),
+                                     is_prompt=(output_len == 0),
+                                     seq_data={group_id: seq_data},
+                                     sampling_params=sampling_params,
+                                     block_tables=block_tables,
+                                     lora_request=lora_request)
+
+    def profile_run(self) -> None:
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        kv_caches = [None] * num_layers
+        bind_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            [kv_caches])
+        max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
+        max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
+                             self.scheduler_config.max_num_seqs)
+        self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
+        return
+
+    def warmup_scenario(self,
+                        batch_size,
+                        seq_len,
+                        is_prompt,
+                        is_pt_profiler_run=False,
+                        is_lora_profile_run=False) -> None:
+        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
+        scenario_name = ("warmup_"
+                         f"{'prompt' if is_prompt else 'decode'}_"
+                         f"bs{batch_size}_"
+                         f"seq{seq_len}_"
+                         f"graphs{'T' if use_graphs else 'F'}")
+        # This represents the maximum number of different requests
+        # that will have unique loras, and therefore the max amount of memory
+        # consumption. Create dummy lora request copies from the lora request
+        # passed in, which contains a lora from the lora warmup path.
+ dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config and is_lora_profile_run: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(batch_size) + ] + self.profiler.start('internal', scenario_name) + times = 3 if use_graphs or is_pt_profiler_run else 1 + if is_prompt: + seqs = [ + self.create_dummy_seq_group_metadata( + i, + seq_len, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i in range(batch_size) + ] + else: + # FIXME: seq_len is actually number of blocks + blocks = [seq_len // batch_size for _ in range(batch_size)] + blocks[0] += seq_len % batch_size + seqs = [ + self.create_dummy_seq_group_metadata( + i, + b * self.block_size - 1, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i, b in enumerate(blocks) + ] + torch.hpu.synchronize() + profiler = None + if is_pt_profiler_run and self.is_driver_worker: + profiler = setup_profiler() + profiler.start() + for _ in range(times): + inputs = self.prepare_model_input(seqs) + self.execute_model(inputs, None, warmup_mode=True) + torch.hpu.synchronize() + if profiler: + profiler.step() + if profiler: + profiler.stop() + self.profiler.end() + gc.collect() + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + + def log_warmup(self, phase, i, max_i, batch_size, seq_len): + free_mem = format_bytes( + HabanaMemoryProfiler.current_free_device_memory()) + dim = "num_blocks" + if phase == "Prompt": + dim = "seq_len" + msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len} " + f"free_mem:{free_mem}") + logger.info(msg) + + def warmup_all_buckets(self, buckets, is_prompt): + for i, (batch_size, seq_len) in enumerate(reversed(buckets)): + self.log_warmup('Prompt' if is_prompt else 'Decode', i, + len(buckets), batch_size, seq_len) + self.warmup_scenario(batch_size, seq_len, is_prompt) + + def warmup_graphs(self, + strategy, + buckets, + is_prompt, + available_mem, + 
starting_mem=0, + total_batch_seq=0.001): + total_mem = starting_mem + idx = 0 + phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' + num_candidates = len(buckets) + ordering : Union[Callable[[Any], Tuple[Any, Any]], \ + Callable[[Any], Tuple[Any, Any, Any]]] + if strategy == 'min_tokens': + ordering = lambda b: (b[0] * b[1], b[1], b[0]) + elif strategy == 'max_bs': + ordering = lambda b: (-b[0], b[1]) + else: + raise NotImplementedError( + f'Unsupported graph allocation strategy: {strategy}') + buckets = list(sorted(buckets, key=ordering)) + captured_all = True + for idx, (batch_size, seq_len) in enumerate(buckets): + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size * seq_len if is_prompt else batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + if mem_estimate >= available_mem: + captured_all = False + continue + graphed_bucket = (batch_size, seq_len, is_prompt) + if graphed_bucket in self.graphed_buckets: + continue + self.graphed_buckets.add(graphed_bucket) + self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) + with HabanaMemoryProfiler() as mem_prof: + self.warmup_scenario(batch_size, seq_len, is_prompt) + used_mem = align_workers(mem_prof.consumed_device_memory, + torch.distributed.ReduceOp.MAX) + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + return total_mem, total_batch_seq, captured_all + + def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): + num_candidates = len(buckets) + phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' + graphed = list(c[:2] for c in self.graphed_buckets + if c[2] == is_prompt) + if num_candidates == 0: + num_candidates = 1 + msg = (f'{phase} captured:{len(graphed)} ' + f'({100 * len(graphed) / num_candidates:.1f}%) ' + f'used_mem:{format_bytes(total_mem)} ' + f'buckets:{sorted(list(graphed))}') + logger.info(msg) + + @torch.inference_mode() + def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + if profile := os.environ.get('VLLM_PT_PROFILE', None): + phase, bs, seq_len, graph = profile.split('_') + is_prompt = phase == 'prompt' + graphs = graph == 't' + if graphs: + self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, True) + raise AssertionError("Finished profiling") + if self.skip_warmup: + logger.info("Skipping warmup...") + return + self.profiler.start('internal', 'warmup') + max_blocks = kv_caches[0][0].size(0) + + self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \ + generate_prompt_buckets( + self.bucketing_global_state.prompt_bs_bucket_cfg, + self.bucketing_global_state.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) + + msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} " + f"prompt buckets [bs, seq]: \ + {list(sorted(self.bucketing_global_state.prompt_buckets))}") + logger.info(msg) + + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + + self.bucketing_global_state.decode_buckets = generate_decode_buckets( + self.bucketing_global_state.decode_bs_bucket_cfg, + self.bucketing_global_state.decode_block_bucket_cfg, max_blocks) + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", + len(self.bucketing_global_state.decode_buckets), + 
list(sorted(self.bucketing_global_state.decode_buckets))) + + if not htorch.utils.internal.is_lazy() and not self.enforce_eager: + cache_size_limit = 1 + 3 * ( + len(self.bucketing_global_state.prompt_buckets) + + len(self.bucketing_global_state.decode_buckets)) + torch._dynamo.config.cache_size_limit = max( + cache_size_limit, torch._dynamo.config.cache_size_limit) + # Multiply by 8 to follow the original default ratio between + # the cache_size_limit and accumulated_cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = max( + cache_size_limit * 8, + torch._dynamo.config.accumulated_cache_size_limit) + + start_mem = HabanaMemoryProfiler.current_device_memory_usage() + start_time = time.perf_counter() + + compile_only_mode_context = functools.partial(bc.env_setting, + "PT_COMPILE_ONLY_MODE", + True) + can_use_compile_only_mode = True + try: + with compile_only_mode_context(): + pass + logger.debug("Using PT_COMPILE_ONLY_MODE.") + except KeyError: + can_use_compile_only_mode = False + logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' + 'Warmup time will be negatively impacted. ' + 'Please update Gaudi Software Suite.') + with compile_only_mode_context( + ) if can_use_compile_only_mode else contextlib.nullcontext(): + self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, + True) + self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, + False) + + if not self.enforce_eager and htorch.utils.internal.is_lazy(): + assert self.mem_margin is not None, \ + ("HabanaWorker.determine_num_available_blocks needs " + "to be called before warming up the model.") + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_margin + graph_free_mem = align_workers(graph_free_mem, + torch.distributed.ReduceOp.MIN) + prompt_graph_mem_ratio = float( + os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3')) + prompt_available_memory = (prompt_graph_mem_ratio * + graph_free_mem) + decode_available_memory = (graph_free_mem - + prompt_available_memory) + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(prompt_available_memory)} for prompt and " + f"{format_bytes(decode_available_memory)} for decode " + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") + logger.info(msg) + prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY', + 'min_tokens') + decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', + 'max_bs') + mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ + self.warmup_graphs( + prompt_strategy, self.bucketing_global_state.prompt_buckets, + True, prompt_available_memory) + mem_post_decode, decode_batch_seq, decode_captured_all = \ + self.warmup_graphs( + decode_strategy, self.bucketing_global_state.decode_buckets, + False, decode_available_memory) + + # Not all prompt buckets were captured, but all decode buckets + # were captured and we have some free graph-allocated space + # left. Let's try to use it for capturing more prompt buckets. 
+ if (mem_post_decode + mem_post_prompt < graph_free_mem + and not prompt_captured_all and decode_captured_all): + mem_post_prompt, _, prompt_captured_all = ( + self.warmup_graphs( + prompt_strategy, + self.bucketing_global_state.prompt_buckets, True, + graph_free_mem - mem_post_prompt - mem_post_decode, + mem_post_prompt, prompt_batch_seq)) + + # Not all decode buckets were captured, but all prompt buckets + # were captured and we have some free graph-allocated space + # left. Let's try to use it for capturing more decode buckets. + if mem_post_decode + mem_post_prompt < graph_free_mem \ + and not decode_captured_all \ + and prompt_captured_all: + mem_post_decode, _, _ = self.warmup_graphs( + decode_strategy, + self.bucketing_global_state.decode_buckets, False, + graph_free_mem - mem_post_prompt - mem_post_decode, + mem_post_decode, decode_batch_seq) + + self.log_graph_warmup_summary( + self.bucketing_global_state.prompt_buckets, True, + mem_post_prompt) + self.log_graph_warmup_summary( + self.bucketing_global_state.decode_buckets, False, + mem_post_decode) + + end_time = time.perf_counter() + end_mem = HabanaMemoryProfiler.current_device_memory_usage() + elapsed_time = end_time - start_time + msg = ( + f"Warmup finished in {elapsed_time:.0f} secs, " + f"allocated {format_bytes(end_mem - start_mem)} of device memory") + logger.info(msg) + self.profiler.end() + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + @property + def mem_margin(self) -> Optional[int]: + return self._mem_margin + + @mem_margin.setter + def mem_margin(self, value): + self._mem_margin = value + + +def _maybe_wrap_in_hpu_graph(*args, **kwargs): + return htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs) + + +class HabanaProfilerCounterHelper: + + def __init__(self): + self.niter = 0 + self.average_real_throughput = None + self.logged_once = False + self.real_seq_lens = [] + self.prompt_seq_lens = [] + + def capture_seq_group_metadata_stats(self, seq_group_metadata_list): + self.real_seq_lens = [ + len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids) + for seq_group_metadata in seq_group_metadata_list + for seq_data in seq_group_metadata.seq_data.values() + ] + self.prompt_seq_lens = [ + len(seq_data.prompt_token_ids) + for seq_group_metadata in seq_group_metadata_list + for seq_data in seq_group_metadata.seq_data.values() + ] + + def get_counter_dict(self, cache_config, duration, seq_len, + batch_size_padded, real_batch_size, is_prompt): + throughput = batch_size_padded / (duration / 1e6) + throughput_effective = real_batch_size / (duration / 1e6) + + real_max_seq_len = max(self.real_seq_lens) + real_num_tokens = sum(self.real_seq_lens) + padded_num_tokens = batch_size_padded * seq_len + batch_token_utilization = real_num_tokens / padded_num_tokens + if self.average_real_throughput is None: + self.average_real_throughput = throughput_effective + else: # https://www.heikohoffmann.de/htmlthesis/node134.html + self.average_real_throughput = self.average_real_throughput + 1 / ( + self.niter + 1) * (throughput_effective - + self.average_real_throughput) + phase = "prompt" if is_prompt else "decode" + counters = { + f'{phase}_bucket_batch_size': batch_size_padded, + f'{phase}_batch_size': real_batch_size, + f'{phase}_bucket_seq_len': seq_len, + f'{phase}_seq_len': real_max_seq_len, + f'{phase}_bucket_gen_throughput': throughput, + 
f'{phase}_real_gen_throughput': throughput_effective, + f'{phase}_batch_token_utilization': batch_token_utilization, + 'average_real_throughput': self.average_real_throughput, + 'engine_iteration': self.niter, + } + self.niter += 1 + if is_prompt: + prompt_bucket_in_throughput = (seq_len * batch_size_padded) / ( + duration / 1e6) + prompt_real_in_throughput = sum( + self.prompt_seq_lens) / (duration / 1e6) + counters[ + f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput + counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput + + # KV cache might not be created yet (e.g. for profiling run) + if cache_config.num_gpu_blocks is not None and \ + cache_config.num_gpu_blocks != 0: + cache_num_blocks_used = [ + math.ceil(sl / cache_config.block_size) + for sl in self.real_seq_lens + ] + cache_total_num_blocks_used = sum(cache_num_blocks_used) + num_cache_blocks = cache_config.num_gpu_blocks + cache_total_num_free_blocks = \ + num_cache_blocks - cache_total_num_blocks_used + cache_computed_utilization = \ + cache_total_num_blocks_used / num_cache_blocks + max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size) + batch_block_utilization = cache_total_num_blocks_used / ( + batch_size_padded * max_blocks_per_seq) + counters['cache_num_blocks_used'] = cache_total_num_blocks_used + counters['cache_num_free_blocks'] = cache_total_num_free_blocks + counters['cache_computed_utilization'] = cache_computed_utilization + counters[ + f'{phase}_batch_block_utilization'] = batch_block_utilization + if not self.logged_once: + counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks + counters[ + 'const_gpu_memory_utilization'] = \ + cache_config.gpu_memory_utilization + counters['const_block_size'] = cache_config.block_size + self.logged_once = True + return counters + + +def unwrap_model(model): + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + return unwrap_model(model._orig_mod) + else: + model = list(vars(model)['_modules'].values())[0] + modules = list(vars(model)['_modules'].values()) + return modules + + +class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): + """ + GPU model runner with sampling step. + """ + _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = ( + ModelInputForHPUWithSamplingMetadata) + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForHPUWithSamplingMetadata: + return ( + ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + + @torch.inference_mode() + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForHPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + The result tensors and data structure also batches input in prefill + -> decode order. For example, + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + If cuda graph is required, this API automatically pads inputs. 
+ """ + with self.profiler.record_event('internal', 'prepare_input_tensors'): + assert seq_group_metadata_list is not None + if self.profiler.enabled: + self.profiler_counter_helper.capture_seq_group_metadata_stats( + seq_group_metadata_list=seq_group_metadata_list) + model_input, sampling_metadata = self.prepare_input_tensors( + seq_group_metadata_list) + assert model_input.attn_metadata is not None + is_prompt = model_input.attn_metadata.is_prompt + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + def finish_measurements(self): + from neural_compressor.torch.quantization import finalize_calibration + finalize_calibration(self.model.model) + + def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): + cfg = (batch_size, seq_len, is_prompt) + seen = cfg in self.seen_configs + self.seen_configs.add(cfg) + if not seen and not warmup_mode: + phase = 'prompt' if is_prompt else 'decode' + logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", + phase, batch_size, seq_len) + + def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], + is_prompt: bool): + ''' + This is a helper function to create the mask for lora computations. + Lora Mask is needed to ensure we match the correct lora weights for the + for the request. + For Prompt phase we have + lora_mask with shape (batch_size * seq_len, max_loras * max_rank) + lora_logits_mask with shape (batch_size, max_loras * max_rank) + For Decode phase we have both + lora_mask and lora_logits_mask with shape + (batch_size, max_loras * max_rank) + ''' + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + lora_index = 0 + + if self.lora_config: + if is_prompt: + lora_mask = torch.zeros( + input_tokens.shape[0] * input_tokens.shape[1], + (self.lora_config.max_loras) *\ + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + lora_logits_mask = torch.zeros( + input_tokens.shape[0], (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + ones = torch.ones(input_tokens.shape[1], + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + logit_ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_row = i * input_tokens.shape[1] + end_row = start_row + input_tokens.shape[1] + start_col = lora_index * self.lora_config.max_lora_rank + end_col = start_col + self.lora_config.max_lora_rank + lora_mask[start_row:end_row, start_col:end_col] = ones + lora_logits_mask[i, start_col:end_col] = logit_ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_logits_mask.to('hpu') + else: + lora_mask = torch.zeros(input_tokens.shape[0], + (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_pos = lora_index * self.lora_config.max_lora_rank + end_pos = start_pos + self.lora_config.max_lora_rank + lora_mask[i, start_pos:end_pos] = ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_mask + + return lora_mask, lora_logits_mask 
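+    # Illustrative (hypothetical) decode-phase example for create_lora_mask:
+    # with batch_size=2, max_loras=2, max_lora_rank=8 and lora_ids=[1, 0],
+    # lora_mask has shape (2, 16); row 0 gets ones in the 8 columns mapped
+    # to LoRA id 1 via lora_index_to_id, while row 1 (no LoRA) stays zero.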
+ + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForHPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + warmup_mode=False, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError( + "num_steps > 1 is not supported in HPUModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + sampling_metadata = model_input.sampling_metadata + real_batch_size = model_input.real_batch_size + batch_size_padded = model_input.batch_size_padded + assert input_tokens is not None + assert input_positions is not None + assert sampling_metadata is not None + assert attn_metadata is not None + is_prompt = attn_metadata.is_prompt + assert is_prompt is not None + batch_size = input_tokens.size(0) + seq_len = self._seq_len(attn_metadata) + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + self._check_config(batch_size, seq_len, is_prompt, warmup_mode) + + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + if self.lora_config: + assert model_input.lora_ids is not None + lora_mask, lora_logits_mask = self.create_lora_mask( + input_tokens, model_input.lora_ids, attn_metadata.is_prompt) + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "attn_metadata": self.trim_attn_metadata(attn_metadata), + "intermediate_tensors": intermediate_tensors, + "lora_mask": lora_mask, + "virtual_engine": model_input.virtual_engine, + **(model_input.multi_modal_kwargs or {}), + } + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs}) + + htorch.core.mark_step() + if self.is_driver_worker: + model_event_name = ("model_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}") + else: + model_event_name = 'model_executable' + with self.profiler.record_event('internal', model_event_name): + hidden_states = self.model.forward( + **execute_model_kwargs, + selected_token_indices=sampling_metadata.selected_token_indices + ) + + if self.lora_config: + LoraMask.setLoraMask( + lora_logits_mask.index_select( + 0, sampling_metadata.selected_token_indices)) + + # Compute the logits. + with self.profiler.record_event( + 'internal', ('compute_logits_' + f'{"prompt" if is_prompt else "decode"}_bs' + f'{batch_size}_' + f'seq{seq_len}')): + sampling_metadata.selected_token_indices = None + logits = self.model.compute_logits(hidden_states, + sampling_metadata) + htorch.core.mark_step() + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. 
+ with self.profiler.record_event( + 'internal', ('sample_' + f'{"prompt" if is_prompt else "decode"}_' + f'bs{batch_size}_' + f'seq{seq_len}')): + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + output.outputs = output.outputs[:real_batch_size] + htorch.core.mark_step() + + if self.is_driver_worker and self.profiler.enabled: + # Stop recording 'execute_model' event + self.profiler.end() + event_end = self.profiler.get_timestamp_us() + counters = self.profiler_counter_helper.get_counter_dict( + cache_config=self.cache_config, + duration=event_end - self.event_start, + seq_len=seq_len, + batch_size_padded=batch_size_padded, + real_batch_size=real_batch_size, + is_prompt=is_prompt) + self.profiler.record_counter(self.event_start, counters) + return [output] + + def shutdown_inc(self): + can_finalize_inc = False + from contextlib import suppress + with suppress(AttributeError): + can_finalize_inc = (self.model_config.quantization == 'inc') and \ + (self.model.model is not None) and \ + self.inc_initialized_successfully and \ + not getattr(self, "_is_inc_finalized", False) + if can_finalize_inc: + from neural_compressor.torch.quantization import ( + finalize_calibration) + finalize_calibration(self.model.model) + self._is_inc_finalized = True + + def __del__(self): + self.shutdown_inc() diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py new file mode 100644 index 0000000..ccb175d --- /dev/null +++ b/vllm/worker/hpu_worker.py @@ -0,0 +1,482 @@ +# SPDX-License-Identifier: Apache-2.0 + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +import contextlib +import gc +import os +from typing import List, Optional, Set, Tuple, Type + +import habana_frameworks.torch as htorch # noqa:F401 +import torch +import torch.distributed +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest +from vllm.utils import bind_kv_cache +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.hpu_model_runner import HPUModelRunner +from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, + WorkerInput) + +logger = init_logger(__name__) + + +class HPUWorker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a HPU. + + Each worker is associated with a single HPU. The worker is responsible for + maintaining the KV cache and executing the model on the HPU. In case of + distributed inference, each worker is assigned a partition of the model. 
+ """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[ModelRunnerBase]] = None, + ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner: HPUModelRunner = HPUModelRunner( + vllm_config=vllm_config, is_driver_worker=is_driver_worker) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[HPUCacheEngine] + # Initialize gpu_cache as pooling models don't initialize kv_caches + self.hpu_cache: Optional[List[List[torch.Tensor]]] = None + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.HPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + + def _set_env_vars(self): + local_rank = self.local_rank + if self.parallel_config.world_size == 1: + local_rank = -1 + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["ID"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) + os.environ["RANK"] = str(self.rank) + + def init_device(self) -> None: + if self.device_config.device.type == "hpu": + self.device = torch.device("hpu") + torch.hpu.set_device(self.device) + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + if self.model_config.quantization == 'inc': + self._set_env_vars() + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! 
# noqa:E501
+        # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
+        # VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
+        # VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
+        log_graph_compilation_all = os.environ.get(
+            'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
+        log_graph_compilation = os.environ.get(
+            'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
+            '0') != '0' or log_graph_compilation_all
+        log_cpu_fallbacks_all = os.environ.get(
+            'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
+        log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
+                                           '0') != '0' or log_cpu_fallbacks_all
+        if (log_graph_compilation or log_cpu_fallbacks) and \
+            execute_model_req is not None:
+            from habana_frameworks.torch.hpu.metrics import metric_localcontext
+            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+            is_prompt = any([
+                seq_group_metadata.is_prompt
+                for seq_group_metadata in seq_group_metadata_list
+            ])
+            max_context_len = max([
+                max([
+                    len(v.prompt_token_ids) + len(v.output_token_ids)
+                    for v in seq_group_metadata.seq_data.values()
+                ]) for seq_group_metadata in seq_group_metadata_list
+            ])  # longest prompt+output length across all scheduled sequences
+            max_num_blocks = (
+                (max_context_len - 1) // self.cache_config.block_size) + 1
+            input_stats = (f'is_prompt: {is_prompt}, '
+                           f'num_seqs: {len(seq_group_metadata_list)}, '
+                           f'max_context_len: {max_context_len}, '
+                           f'max_num_blocks {max_num_blocks}')
+            gc_ctx = metric_localcontext(
+                "graph_compilation"
+            ) if log_graph_compilation else contextlib.nullcontext()
+            cpu_fallback_ctx = metric_localcontext(
+                "cpu_fallback"
+            ) if log_cpu_fallbacks else contextlib.nullcontext()
+            with gc_ctx as gc_local_metric, \
+                cpu_fallback_ctx as cpu_fallback_local_metric:
+                output = LocalOrDistributedWorkerBase.execute_model(
+                    self, execute_model_req)
+            if (log_graph_compilation and gc_local_metric.stats()[0][1]
+                    > 0) or log_graph_compilation_all:
+                msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
+                       f"{gc_local_metric.stats()}, {input_stats}")
+                logger.warning(msg)
+            if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1]
+                    > 0) or log_cpu_fallbacks_all:
+                msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
+                       f"{cpu_fallback_local_metric.stats()}, {input_stats}")
+                logger.warning(msg)
+
+            return output
+
+        output = LocalOrDistributedWorkerBase.execute_model(
+            self, execute_model_req)
+        return output
+
+    @torch.inference_mode()
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Profiles the peak memory usage of the model to determine how many
+        KV blocks may be allocated without OOMs.
+
+        The engine will first profile the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Profile the memory usage of the model and get the maximum number of
+        # cache blocks that can be allocated with the remaining free memory.
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
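+        # Worked example (hypothetical numbers): with 64 GiB of free HPU
+        # memory, gpu_memory_utilization=0.9 and VLLM_GRAPH_RESERVED_MEM=0.1,
+        # roughly 57.6 GiB is treated as usable, about 5.76 GiB of that is
+        # held back for HPUGraphs, and about 51.8 GiB is given to the KV
+        # cache, which is then divided by the per-block size to obtain
+        # num_hpu_blocks.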
+ with HabanaMemoryProfiler() as m: + self.model_runner.profile_run() + torch.hpu.synchronize() + msg = ("Model profiling run " + f"took {m.get_summary_string()}") + logger.info(msg) + # At this point we should've allocated the maximum workspace for all + # recipes we will use the extra memory for graphs/blocks + free_hpu_memory = torch.hpu.mem_get_info()[0] + + cache_block_size = self.get_cache_block_size_bytes() + graph_reserved_mem = (float( + os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1')) + if not self.model_config.enforce_eager else 0) + graph_headroom = 1 - graph_reserved_mem + available_hpu_memory = free_hpu_memory * \ + self.cache_config.gpu_memory_utilization + hpu_memory_margin = free_hpu_memory * ( + 1 - self.cache_config.gpu_memory_utilization) + self.model_runner.mem_margin = hpu_memory_margin + cache_size_bytes = available_hpu_memory * graph_headroom + graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom) + msg = ( + f"Free device memory: {format_bytes(free_hpu_memory)}, " + f"{format_bytes(available_hpu_memory)} usable " + f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization})," + f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs " + f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), " + f"{format_bytes(cache_size_bytes)} reserved for KV cache") + logger.info(msg) + num_hpu_blocks = int(cache_size_bytes // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_hpu_blocks = max(num_hpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + + gc.collect() + return num_hpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. + + This also warms up the model, which may record CUDA graphs. + """ + raise_if_cache_size_invalid( + num_gpu_blocks, self.cache_config.block_size, + self.model_config.max_model_len, + self.parallel_config.pipeline_parallel_size) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + with HabanaMemoryProfiler() as m: + self._init_cache_engine() + torch.hpu.synchronize() + msg = ("Initializing cache engine " + f"took {m.get_summary_string()}") + logger.info(msg) + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None + self.cache_engine = [ + HPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.hpu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.hpu_cache) + + def _warm_up_model(self) -> None: + # NOTE(kzawora): We should use virtual engine index here + # for pipeline parallelism. Using 0 for now. + assert self.hpu_cache is not None + self.model_runner.warmup_model(self.hpu_cache[0]) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. 
+ set_random_seed(self.model_config.seed) + + def finish_measurements(self): + self.model_runner.finish_measurements() + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.hpu_cache + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + @torch.inference_mode() + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + # Issue cache operations. + if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in) + if (worker_input.blocks_to_swap_out is not None + and worker_input.blocks_to_swap_out.numel() > 0): + self.cache_engine[virtual_engine].swap_out( + worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def shutdown_inc(self): + self.model_runner.shutdown_inc() + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + def get_cache_block_size_bytes(self) -> int: + """Get the size of the KV cache block size in bytes. 
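+
+        The value is computed by HPUCacheEngine from the current cache, model
+        and parallel configurations and is the unit by which
+        determine_num_available_blocks divides the memory reserved for the KV
+        cache.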
+ """ + return HPUCacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + init_distributed_environment(parallel_config.world_size, + rank, + distributed_init_method, + local_rank, + backend='hccl') + + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + torch.distributed.init_process_group( + backend="hccl", + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup & checking conformance. + dummy_tensor_hpu = torch.ones(1).to('hpu') + torch.distributed.all_reduce(dummy_tensor_hpu) + assert dummy_tensor_hpu.item() == parallel_config.world_size + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, + pipeline_parallel_size) -> None: + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size) + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). 
Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + +class HPUCacheEngine(CacheEngine): + + def _allocate_kv_cache( + self, + num_blocks: int, + device: str, + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for _ in range(self.num_attention_layers): + key_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + device=device) + value_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + device=device) + kv_layer = (key_cache, value_cache) + kv_cache.append(kv_layer) + return kv_cache diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py new file mode 100644 index 0000000..2cfd2cc --- /dev/null +++ b/vllm/worker/model_runner.py @@ -0,0 +1,2064 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import gc +import inspect +import itertools +import time +import weakref +from contextlib import contextmanager +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from tqdm import tqdm + +import vllm.envs as envs +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import CompilationLevel, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.distributed import get_kv_transfer_group, get_pp_group +from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, + graph_capture) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata, SamplingMetadataCache +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.model_executor.models.utils import set_cpu_offload_max_bytes +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs, MultiModalPlaceholderMap, + MultiModalRegistry) +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, + async_tensor_h2d, flatten_2d_lists, + is_pin_memory_available, supports_dynamo, + weak_ref_tensor) +from vllm.worker.model_runner_base import ( + InputProcessingError, ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + 
_init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +LORA_WARMUP_RANK = 8 + +_NUM_WARMUP_ITERS = 2 + +TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") + +# For now, bump up cache limits for recompilations during CUDA graph warmups. +# torch._dynamo.config.cache_size_limit = 128 +# torch._dynamo.config.accumulated_cache_size_limit = 128 + + +@dataclass(frozen=True) +class ModelInputForGPU(ModelRunnerInputBase): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + token_types: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + attn_metadata: Optional["AttentionMetadata"] = None + prompt_adapter_mapping: Optional[PromptAdapterMapping] = None + prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None + finished_requests_ids: Optional[List[str]] = None + virtual_engine: int = 0 + async_callback: Optional[Callable] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + previous_hidden_states: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForGPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForGPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + # Exclude `async_callback` to be able to pickle this object + def __getstate__(self): + state = self.__dict__.copy() + del state["async_callback"] + return state + + # TODO: What happens when we depickle this object? + # How can we update this callback to properly pass it to the engine? + def __setstate__(self, state): + self.__dict__.update(state) + self.__dict__.update({'async_callback': None}) + + +@dataclass(frozen=True) +class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + # Used for speculative decoding. We do not broadcast it because it is only + # used by the driver worker. 
+ is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForGPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): + """Build ModelInputForGPU from SequenceGroupMetadata.""" + + # Note: ideally we would be using a dataclass(kw_only=True) + # here, so that this can be subclassed easily, + # but kw_only is not supported in python<3.10. + class InterDataForSeqGroup: + """Intermediate data for the current sequence group.""" + + def simple_reinit(self): + self.input_tokens[0].clear() # type: ignore + self.input_positions[0].clear() # type: ignore + self.token_types[0].clear() # type: ignore + self.mrope_input_positions = None # type: ignore + self.seq_lens[0] = 0 # type: ignore + self.orig_seq_lens[0] = 0 # type: ignore + self.query_lens[0] = 0 # type: ignore + self.context_lens[0] = 0 # type: ignore + self.curr_sliding_window_blocks[0] = 0 # type: ignore + self.lora_index_mapping.clear() # type: ignore + self.lora_prompt_mapping.clear() # type: ignore + self.lora_requests.clear() # type: ignore + self.prompt_adapter_index_mapping.clear() # type: ignore + self.prompt_adapter_prompt_mapping.clear() # type: ignore + + def __init__( + self, + *, + # From sequence group metadata. + request_id: str, + seq_ids: List[int], + is_prompt: bool, + block_tables: Optional[Dict[int, List[int]]], + computed_block_nums: List[int], + n_seqs: int = 0, + + # Input tokens and positions. + input_tokens: Optional[List[List[int]]] = None, + input_positions: Optional[List[List[int]]] = None, + token_types: Optional[List[List[int]]] = None, + mrope_input_positions: Optional[List[List[List[int]]]] = None, + + # The sequence length (may be capped to the sliding window). + seq_lens: Optional[List[int]] = None, + # The original sequence length (before applying sliding window). + # This is used to compute slot mapping. + orig_seq_lens: Optional[List[int]] = None, + # The query length. + query_lens: Optional[List[int]] = None, + # The number of tokens that are already computed. + context_lens: Optional[List[int]] = None, + # The current sliding window block. + curr_sliding_window_blocks: Optional[List[int]] = None, + + # LoRA inputs. + lora_index_mapping: Optional[List[List[int]]] = None, + lora_prompt_mapping: Optional[List[List[int]]] = None, + lora_requests: Optional[Set[LoRARequest]] = None, + + # Prompt adapter inputs. 
+ prompt_adapter_index_mapping: Optional[List[int]] = None, + prompt_adapter_prompt_mapping: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + + # Multi-modal inputs. + multi_modal_kwargs: Optional[MultiModalKwargs] = None, + multi_modal_placeholder_maps: Optional[Dict[ + str, MultiModalPlaceholderMap]] = None, + + # Whether the prefix cache is hit (prefill only). + prefix_cache_hit: bool = False, + reinit: bool = False, + reinit_use_defaults: bool = False, + encoder_seq_len: int = 0, + ): + if reinit: + assert len(self.seq_ids) == len(seq_ids) # type: ignore + for i, seq_id in enumerate(seq_ids): + self.seq_ids[i] = seq_id # type: ignore + else: + self.seq_ids = seq_ids + + self.request_id = request_id + self.is_prompt = is_prompt + self.block_tables = block_tables + self.computed_block_nums = computed_block_nums + self.n_seqs = n_seqs + self.encoder_seq_len = encoder_seq_len + + if reinit: + if len(self.seq_ids) == 1 and reinit_use_defaults: + self.simple_reinit() + else: + if input_tokens: + self.input_tokens = input_tokens + else: + for seq_id in range(len(self.seq_ids)): + self.input_tokens[seq_id].clear() + + if input_positions: + self.input_positions = input_positions + else: + for seq_id in range(len(self.seq_ids)): + self.input_positions[seq_id].clear() + + if token_types: + self.token_types = token_types + else: + for seq_id in range(len(self.seq_ids)): + self.token_types[seq_id].clear() + + self.mrope_input_positions = None + + if seq_lens: + self.seq_lens = seq_lens + else: + for seq_id in range(len(self.seq_ids)): + self.seq_lens[seq_id] = 0 + + if orig_seq_lens: + self.orig_seq_lens = orig_seq_lens + else: + for seq_id in range(len(self.seq_ids)): + self.orig_seq_lens[seq_id] = 0 + + if query_lens: + self.query_lens = query_lens + else: + for seq_id in range(len(self.seq_ids)): + self.query_lens[seq_id] = 0 + + if context_lens: + self.context_lens = context_lens + else: + for seq_id in range(len(self.seq_ids)): + self.context_lens[seq_id] = 0 + + if curr_sliding_window_blocks: + self.curr_sliding_window_blocks = \ + curr_sliding_window_blocks + else: + for seq_id in range(len(self.seq_ids)): + self.curr_sliding_window_blocks[seq_id] = 0 + + if lora_index_mapping: + self.lora_index_mapping = lora_index_mapping + else: + self.lora_index_mapping.clear() + + if lora_prompt_mapping: + self.lora_prompt_mapping = lora_prompt_mapping + else: + self.lora_prompt_mapping.clear() + + if lora_requests: + self.lora_requests = lora_requests + else: + self.lora_requests.clear() + + if prompt_adapter_index_mapping: + self.prompt_adapter_index_mapping = \ + prompt_adapter_index_mapping + else: + self.prompt_adapter_index_mapping.clear() + + if prompt_adapter_prompt_mapping: + self.prompt_adapter_prompt_mapping = \ + prompt_adapter_prompt_mapping + else: + self.prompt_adapter_prompt_mapping.clear() + + else: + self.input_tokens = input_tokens or [] + self.input_positions = input_positions or [] + self.token_types = token_types or [] + self.mrope_input_positions = mrope_input_positions or None + self.seq_lens = seq_lens or [] + self.orig_seq_lens = orig_seq_lens or [] + self.query_lens = query_lens or [] + self.context_lens = context_lens or [] + self.curr_sliding_window_blocks = \ + curr_sliding_window_blocks or [] + + self.lora_index_mapping = lora_index_mapping or [] + self.lora_prompt_mapping = lora_prompt_mapping or [] + self.lora_requests = lora_requests or set() + + self.prompt_adapter_index_mapping = ( + prompt_adapter_index_mapping or []) + 
self.prompt_adapter_prompt_mapping = ( + prompt_adapter_prompt_mapping or []) + + self.prompt_adapter_request = prompt_adapter_request + self.multi_modal_kwargs = multi_modal_kwargs + self.multi_modal_placeholder_maps = multi_modal_placeholder_maps + self.prefix_cache_hit = prefix_cache_hit + + self.n_seqs = len(self.seq_ids) + + if not reinit: + self.__post_init__() + + def __post_init__(self): + self.n_seqs = len(self.seq_ids) + + self.input_tokens = [[] for _ in range(self.n_seqs)] + self.input_positions = [[] for _ in range(self.n_seqs)] + self.token_types = [[] for _ in range(self.n_seqs)] + self.mrope_input_positions = None + self.seq_lens = [0] * self.n_seqs + self.orig_seq_lens = [0] * self.n_seqs + self.query_lens = [0] * self.n_seqs + self.context_lens = [0] * self.n_seqs + self.curr_sliding_window_blocks = [0] * self.n_seqs + + self.lora_index_mapping = [] + self.lora_prompt_mapping = [] + + def gen_inter_data_builder(self, num_seqs: int): + return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( + request_id="", + seq_ids=[0] * num_seqs, + is_prompt=True, + block_tables=None, + computed_block_nums=[]) + + def init_cached_inter_data(self, *args, **kwargs): + assert len(args) == 0 + assert "seq_ids" in kwargs + seq_ids = kwargs["seq_ids"] + num_seqs = len(seq_ids) + + # The inter-data cache is per model_runner + inter_data_cache = self.runner.inter_data_cache + if num_seqs not in inter_data_cache: + inter_data_cache[num_seqs] = PyObjectCache( + self.gen_inter_data_builder(num_seqs)) + + obj = inter_data_cache[num_seqs].get_object() + obj.__init__(*args, **kwargs) + return obj + + def reset_cached_inter_data(self): + for cache in self.runner.inter_data_cache.values(): + cache.reset() + + def __init__(self, + runner: "GPUModelRunnerBase", + finished_requests_ids: Optional[List[str]] = None): + super().__init__() + # Compute functions for each sequence in a sequence group. + # WARNING: The order of the functions matters! + self.per_seq_compute_fns = [ + self._compute_lens, + self._compute_for_prefix_cache_hit, + self._compute_for_sliding_window, + self._compute_lora_input, + ] + # Compute functions for each sequence group. + # WARNING: The order of the functions matters! + self.per_seq_group_compute_fns = [ + self._compute_prompt_adapter_input, + self._compute_multi_modal_input, + ] + + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.scheduler_config = self.runner.scheduler_config + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.enable_lora = self.runner.lora_config is not None + self.enable_prompt_adapter = (self.runner.prompt_adapter_config + is not None) + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + + # Attention metadata inputs. + if self.attn_backend is not None: + # spec decode (e.g. Medusa) does not have atten backend + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self)) + + # Engine/Model configurations. 
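+        # (chunked-prefill flag plus, when a sliding window is configured, how
+        # many blocks the window spans; e.g. with a hypothetical
+        # sliding_window=1000 and block_size=16, (1000 + 15) // 16 = 63 blocks
+        # and a block-aligned window of 63 * 16 = 1008 tokens.)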
+ self.chunked_prefill_enabled = ( + self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled) + if self.sliding_window is not None: + self.sliding_window_blocks = ( + self.sliding_window + self.block_size - 1) // self.block_size + self.block_aligned_sliding_window = \ + self.sliding_window_blocks * self.block_size + + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.finished_requests_ids = finished_requests_ids + + # if the current batch is decode-only. + # will be set to False if there is any non-decode request. + self.decode_only = True + + # Intermediate data (data in CPU before going to GPU) for + # the current sequence group. + self.inter_data_list: List[ + ModelInputForGPUBuilder.InterDataForSeqGroup] = [] + + self.attn_metadata_builder.prepare() + + def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Compute context length, sequence length and tokens + for the given sequence data. + """ + seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] + token_chunk_size = seq_group_metadata.token_chunk_size + + # Compute context length (the number of tokens that are + # already computed) and sequence length (total number of tokens). + + seq_len = seq_data.get_len() + if inter_data.is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_len, context_len + token_chunk_size) + elif self.runner.scheduler_config.is_multi_step or \ + self.runner.model_config.is_encoder_decoder: + context_len = seq_len - 1 + else: + context_len = seq_data.get_num_computed_tokens() + + # Compute tokens. + tokens = seq_data.get_token_ids()[context_len:seq_len] + token_types = seq_group_metadata.token_type_ids + + inter_data.seq_lens[seq_idx] = seq_len + inter_data.orig_seq_lens[seq_idx] = seq_len + inter_data.context_lens[seq_idx] = context_len + inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.token_types[seq_idx].extend( + token_types if token_types else []) + inter_data.query_lens[seq_idx] = seq_len - context_len + + if seq_data.mrope_position_delta is not None: + if inter_data.mrope_input_positions is None: + inter_data.mrope_input_positions = [None] * inter_data.n_seqs + + inter_data.mrope_input_positions[ + seq_idx] = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + + def _compute_for_prefix_cache_hit( + self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Check if hit prefix cache (i.e., some blocks are already computed). + If hit, update input tokens and positions to only compute the + remaining blocks. + """ + computed_block_nums = inter_data.computed_block_nums + + # Note that prefix caching does not support sliding window. + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and inter_data.is_prompt) + inter_data.prefix_cache_hit = prefix_cache_hit + + if not prefix_cache_hit: + return + + assert computed_block_nums is not None + # The cache hit prompt tokens in this sequence. Note that + # this may be larger than the sequence length if chunked + # prefill is enabled. 
+ prefix_cache_len = len(computed_block_nums) * self.block_size + seq_group_metadata.seq_data[inter_data.seq_ids[ + seq_idx]].update_num_cached_tokens(prefix_cache_len) + + # The number of so far computed prompt tokens in this sequence. + context_len = inter_data.context_lens[seq_idx] + # The total number of prompt tokens in this sequence. + # When chunked prefill is enabled, this is the token number of + # computed chunks + current chunk. + seq_len = inter_data.seq_lens[seq_idx] + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + uncomputed_start = prefix_cache_len - context_len + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][uncomputed_start:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][uncomputed_start:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + uncomputed_start:] + context_len = prefix_cache_len + + inter_data.context_lens[seq_idx] = context_len + inter_data.query_lens[ + seq_idx] = inter_data.seq_lens[seq_idx] - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][-1:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][-1:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + -1:] + inter_data.query_lens[seq_idx] = 1 + inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 + + def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Update seq_len and curr_sliding_window_block for the given + sequence data (only required by decoding) if sliding window is enabled. + """ + curr_sliding_window_block = 0 + sliding_seq_len = inter_data.seq_lens[seq_idx] + if not inter_data.is_prompt and self.sliding_window is not None: + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. 
+ curr_sliding_window_block = self.sliding_window_blocks + # number of elements in last block + suff_len = inter_data.seq_lens[seq_idx] % self.block_size + sliding_seq_len = min(inter_data.seq_lens[seq_idx], + self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_block += 1 + + inter_data.curr_sliding_window_blocks[ + seq_idx] = curr_sliding_window_block + inter_data.seq_lens[seq_idx] = sliding_seq_len + + def _compute_lora_input(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """If LoRA is enabled, compute LoRA index and prompt mapping.""" + if not self.enable_lora: + return + + lora_id = seq_group_metadata.lora_int_id + if lora_id > 0: + inter_data.lora_requests.add(seq_group_metadata.lora_request) + query_len = inter_data.query_lens[seq_idx] + inter_data.lora_index_mapping.append([lora_id] * query_len) + sampling_params = seq_group_metadata.sampling_params + if sampling_params and sampling_params.prompt_logprobs is not None: + inter_data.lora_prompt_mapping.append([lora_id] * query_len) + elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample: + inter_data.lora_prompt_mapping.append([lora_id]) + else: + inter_data.lora_prompt_mapping.append([]) + + def _compute_prompt_adapter_input( + self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If prompt adapter is enabled, compute index and prompt mapping. + """ + # Note that when is_prompt=True, we expect only one sequence + # in the group. + if not self.enable_prompt_adapter: + return + + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + if prompt_adapter_id <= 0 or not inter_data.is_prompt: + return + + # We expect only one sequence in the group when is_prompt=True. + assert inter_data.n_seqs == 1 + query_len = inter_data.query_lens[0] + inter_data.prompt_adapter_request = ( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens + inter_data.prompt_adapter_index_mapping = [ + prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( + query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1) + + def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If multi-modal data is given, add it to the input.""" + # NOTE: mm_data only includes the subset of multi-modal items that + # intersect with the current prefill positions. + positions = inter_data.input_positions[0] + mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group_metadata, + range(positions[0], positions[0] + len(positions))) + if not mm_data: + return + + if self.runner.mm_registry.has_processor(self.runner.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + inter_data.multi_modal_kwargs = mm_kwargs + inter_data.multi_modal_placeholder_maps = placeholder_maps + + # special processing for mrope position deltas. 
+ if self.runner.model_config.uses_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) + hf_config = self.runner.model_config.hf_config + + inter_data.mrope_input_positions = [None] * inter_data.n_seqs + for seq_idx in range(inter_data.n_seqs): + seq_data = seq_group_metadata.seq_data[ + inter_data.seq_ids[seq_idx]] + token_ids = seq_data.get_token_ids() + + mrope_input_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=inter_data.context_lens[seq_idx], + seq_len=inter_data.seq_lens[seq_idx], + ) + + seq_data.mrope_position_delta = mrope_position_delta + inter_data.mrope_input_positions[ + seq_idx] = mrope_input_positions + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + """Add a sequence group to the builder.""" + seq_ids = seq_group_metadata.seq_data.keys() + n_seqs = len(seq_ids) + is_prompt = seq_group_metadata.is_prompt + + if is_prompt: + assert n_seqs == 1 + self.decode_only = False + + encoder_seq_len = 0 + + if self.runner.model_config.is_encoder_decoder: + encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() + + inter_data = self.init_cached_inter_data( + request_id=seq_group_metadata.request_id, + seq_ids=seq_ids, + is_prompt=is_prompt, + block_tables=seq_group_metadata.block_tables, + computed_block_nums=seq_group_metadata.computed_block_nums, + reinit=True, + reinit_use_defaults=True, + encoder_seq_len=encoder_seq_len) + + self.inter_data_list.append(inter_data) + + for seq_idx in range(n_seqs): + for per_seq_fn in self.per_seq_compute_fns: + per_seq_fn(inter_data, seq_idx, seq_group_metadata) + for per_seq_group_fn in self.per_seq_group_compute_fns: + per_seq_group_fn(inter_data, seq_group_metadata) + + def _use_captured_graph(self, + batch_size: int, + decode_only: bool, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> bool: + return (decode_only and not self.runner.model_config.enforce_eager + and max_decode_seq_len <= self.runner.max_seq_len_to_capture + and max_encoder_seq_len <= self.runner.max_seq_len_to_capture + and batch_size <= self.runner.max_batchsize_to_capture) + + def _get_cuda_graph_pad_size(self, + num_seqs: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> int: + """ + Determine the number of padding sequences required for running in + CUDA graph mode. Returns -1 if CUDA graphs cannot be used. + + In the multi-step + chunked-prefill case, only the first step + has Prefills (if any). The rest of the steps are guaranteed to be all + decodes. In this case, we set up the padding as if all the sequences + are decodes so we may run all steps except the first step in CUDA graph + mode. The padding is accounted for in the multi-step `advance_step` + family of functions. + + Args: + num_seqs (int): Number of sequences scheduled to run. + max_decode_seq_len (int): Greatest of all the decode sequence + lengths. Used only in checking the viablility of using + CUDA graphs. + max_encoder_seq_len (int, optional): Greatest of all the encode + sequence lengths. Defaults to 0. 
Used only in checking the + viability of using CUDA graphs. + Returns: + int: Returns the determined number of padding sequences. If + CUDA graphs is not viable, returns -1. + """ + is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ + self.runner.scheduler_config.chunked_prefill_enabled + decode_only = self.decode_only or is_mscp + if not decode_only: + # Early exit so we can treat num_seqs as the batch_size below. + return -1 + + # batch_size out of this function refers to the number of input + # tokens being scheduled. This conflation of num_seqs as batch_size + # is valid as this is a decode-only case. + batch_size = num_seqs + if not self._use_captured_graph(batch_size, decode_only, + max_decode_seq_len, + max_encoder_seq_len): + return -1 + + graph_batch_size = self.runner.vllm_config.pad_for_cudagraph( + batch_size) + assert graph_batch_size >= batch_size + return graph_batch_size - batch_size + + def build(self) -> ModelInputForGPU: + """Finalize the builder intermediate data and + create on-device tensors. + """ + # Combine and flatten intermediate data. + input_tokens = [] + token_types = [] + for inter_data in self.inter_data_list: + for cur_input_tokens in inter_data.input_tokens: + input_tokens.extend(cur_input_tokens) + for cur_token_types in inter_data.token_types: + token_types.extend(cur_token_types) + + if not input_tokens: + # This may happen when all prefill requests hit + # prefix caching and there is no decode request. + return self.model_input_cls() + + mrope_input_positions: Optional[List[List[int]]] = None + if any(inter_data.mrope_input_positions is not None + for inter_data in self.inter_data_list): + mrope_input_positions = [[] for _ in range(3)] + for idx in range(3): + for inter_data in self.inter_data_list: + msections = inter_data.mrope_input_positions + if msections is None: + for _seq_input_positions in inter_data.input_positions: + mrope_input_positions[idx].extend( + _seq_input_positions) + else: + for _seq_mrope_input_positions in msections: + mrope_input_positions[idx].extend( + _seq_mrope_input_positions[idx]) + input_positions = None + else: + input_positions = [] + for inter_data in self.inter_data_list: + for cur_input_positions in inter_data.input_positions: + input_positions.extend(cur_input_positions) + + seq_lens = [] + query_lens = [] + max_decode_seq_len = 0 + max_encoder_seq_len = 0 + for inter_data in self.inter_data_list: + seq_lens.extend(inter_data.seq_lens) + query_lens.extend(inter_data.query_lens) + if not inter_data.is_prompt: + max_decode_seq_len = max(max_decode_seq_len, + max(inter_data.seq_lens)) + if self.runner.model_config.is_encoder_decoder: + max_encoder_seq_len = max(max_encoder_seq_len, + inter_data.encoder_seq_len) + + # Mapping from request IDs to sequence IDs. Used for Jamba models + # that manages the cache by itself. + request_ids_to_seq_ids = { + data.request_id: data.seq_ids + for data in self.inter_data_list + } + + cuda_graph_pad_size = self._get_cuda_graph_pad_size( + num_seqs=len(seq_lens), + max_decode_seq_len=max_decode_seq_len, + max_encoder_seq_len=max_encoder_seq_len) + + batch_size = len(input_tokens) + if cuda_graph_pad_size != -1: + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + batch_size += cuda_graph_pad_size + + # Tokens and positions. 
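+        # For instance, if the captured graph sizes include 4 and a
+        # decode-only batch holds 3 sequences, cuda_graph_pad_size is 1, so a
+        # single dummy token and position (and a seq_len of 1) are appended
+        # below.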
+ if cuda_graph_pad_size: + input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) + assert self.runner.device is not None + input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, + self.runner.device, + self.runner.pin_memory) + + token_types_tensor = async_tensor_h2d(token_types, torch.long, + self.runner.device, + self.runner.pin_memory) \ + if token_types else None + + if mrope_input_positions is not None: + for idx in range(3): + mrope_input_positions[idx].extend( + itertools.repeat(0, cuda_graph_pad_size)) + input_positions_tensor = async_tensor_h2d(mrope_input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory) + else: + input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) + input_positions_tensor = async_tensor_h2d(input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory) + # Sequence and query lengths. + if cuda_graph_pad_size: + seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) + + # Attention metadata. + attn_metadata = self.attn_metadata_builder.build( + seq_lens, query_lens, cuda_graph_pad_size, batch_size) + + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(r for data in self.inter_data_list + for r in data.lora_requests) + lora_index_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_index_mapping) + for inter_data in self.inter_data_list + ]) + if cuda_graph_pad_size: + lora_index_mapping.extend( + itertools.repeat(0, cuda_graph_pad_size)) + lora_prompt_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_prompt_mapping) + for inter_data in self.inter_data_list + ]) + + lora_mapping = LoRAMapping( + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only)) + + # Prompt adapter data. + prompt_adapter_requests: Set[PromptAdapterRequest] = set() + prompt_adapter_mapping = None + if self.enable_prompt_adapter: + prompt_adapter_requests = set( + data.prompt_adapter_request for data in self.inter_data_list + if data.prompt_adapter_request is not None) + prompt_adapter_index_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_index_mapping + for inter_data in self.inter_data_list + ]) + if cuda_graph_pad_size: + prompt_adapter_index_mapping.extend( + itertools.repeat(0, cuda_graph_pad_size)) + prompt_adapter_prompt_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_prompt_mapping + for inter_data in self.inter_data_list + ]) + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + + # Multi-modal data. + multi_modal_kwargs_list = [ + data.multi_modal_kwargs for data in self.inter_data_list + if data.multi_modal_kwargs is not None + ] + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + + return self.model_input_cls( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + token_types=token_types_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=request_ids_to_seq_ids, + finished_requests_ids=self.finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests) + + +class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): + """ + Helper class for shared methods between GPU model runners. 
+ """ + _model_input_cls: Type[TModelInputForGPU] + _builder_cls: Type[ModelInputForGPUBuilder] + builder: ModelInputForGPUBuilder + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config + cache_config = self.cache_config + + self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states + + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture + self.max_batchsize_to_capture = \ + self.vllm_config.compilation_config.max_capture_size + + self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + {} for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.graph_memory_pool: Optional[Tuple[ + int, int]] = None # Set during graph capture. + + self.has_inner_state = model_config.has_inner_state + + self.in_profile_run = False + + # When using CUDA graph, the input block tables must be padded to + # max_seq_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max seq len to capture / block size). + self.graph_block_tables = np.zeros( + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), + dtype=np.int32) + + # Attention-free but stateful models like Mamba need a placeholder attn + # backend, as the attention metadata is needed to manage internal state. + # However we must bypass attention selection altogether for some models + # used for speculative decoding to avoid a divide-by-zero in + # model_config.get_head_size() + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) if needs_attn_backend else None + if self.attn_backend: + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) + else: + self.attn_state = CommonAttentionState(weakref.proxy(self)) + + # Multi-modal data support + self.input_registry = input_registry + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry \ + .create_input_mapper(model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization + self.model: nn.Module # Set after load_model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None + + set_cpu_offload_max_bytes( + int(self.cache_config.cpu_offload_gb * 1024**3)) + + # Used to cache python objects + self.inter_data_cache: Dict[int, PyObjectCache] = {} + + # Using the PythonizationCache in Pipeline-Parallel clobbers the + # SequenceGroupToSample object. 
In Pipeline-Parallel, we have + # more than 1 Scheduler, resulting in a potential back-to-back + # prepare_model_inputs() call. This clobbers the cached + # SequenceGroupToSample objects, as we reset the cache during + # every prepare_model_inputs() call. + self.sampling_metadata_cache: SamplingMetadataCache = \ + SamplingMetadataCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + if hasattr(self, "_builder_cls"): + # multi-step model runner does not have `_builder_cls` + self.builder = self._builder_cls(weakref.proxy(self)) + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler(self.device) as m: + time_before_load = time.perf_counter() + self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning( + "Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the + # max_position_embeddings of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = ( + self.model.config.max_position_embeddings) + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + time_after_load = time.perf_counter() + + self.model_memory_usage = m.consumed_memory + logger.info("Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load) + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) + + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + backend = self.vllm_config.compilation_config.init_backend( + self.vllm_config) + self.model = torch.compile( + self.model, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + + def get_model(self) -> nn.Module: + return self.model + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) + + def save_tensorized_model( + self, + tensorizer_config: TensorizerConfig, + ) -> None: + from vllm.model_executor.model_loader.loader import TensorizerLoader + TensorizerLoader.save_model( + self.model, + tensorizer_config=tensorizer_config, + ) + + def get_max_block_per_batch(self) -> int: + block_size = self.block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> 
TModelInputForGPU: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + self.builder.prepare(finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + try: + self.builder.add_seq_group(seq_group_metadata) + except Exception as e: + # Raise an exception that tracks the ID of the bad request + raise InputProcessingError(seq_group_metadata.request_id, + str(e)) from e + + self.builder.reset_cached_inter_data() + + return self.builder.build() # type: ignore + + @contextmanager + def set_in_profile_run(self): + self.in_profile_run = True + try: + yield + finally: + self.in_profile_run = False + + @torch.inference_mode() + def profile_run(self) -> None: + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + self._dummy_run(max_num_batched_tokens, max_num_seqs) + + def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]: + assert num_loras > 0 + assert self.lora_manager is not None + + dummy_lora_requests: list[LoRARequest] = [] + with self.lora_manager.dummy_lora_cache(): + for idx in range(num_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + return dummy_lora_requests + + def _remove_dummy_loras(self): + # Remove dummy loras. + assert self.lora_manager is not None + self.remove_all_loras() + + def _dummy_run(self, + max_num_batched_tokens: int, + max_num_seqs: int = 1) -> None: + with self.set_in_profile_run(): + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = \ + SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + + # This represents the maximum number of different requests + # that will have unique loras, and therefore the max amount of + # memory consumption. Create dummy lora request copies from the + # lora request passed in, which contains a lora from the lora + # warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config: + dummy_lora_requests = self._add_dummy_loras( + self.lora_config.max_loras) + assert len(dummy_lora_requests) == self.lora_config.max_loras + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + # Profile memory usage with max_num_sequences sequences and the + # total number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for multi-modal encoding, + # which needs to be accounted for when calculating the GPU blocks + # for vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. 
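+ # Illustrative numbers only (not taken from a real config): with
+ # max_num_batched_tokens = 8192 and a worst-case multimodal prompt of
+ # 3000 tokens, the clamp below gives
+ # min(max_num_seqs, 8192 // 3000) = min(max_num_seqs, 2), so every dummy
+ # sequence can still fit its multimodal placeholder tokens.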
+ + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + if max_mm_tokens > 0: + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. " + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: dummy_data.seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data. + multi_modal_placeholders, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(num_layers) + ] + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = \ + self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + + # Disable KV Scale Calculation for dummy data during profile run + if model_input.attn_metadata is not None: + model_input.attn_metadata.enable_kv_scales_calculation = False + + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.cuda.synchronize() + if self.lora_config: + self._remove_dummy_loras() + + return + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not 
enabled.") + return self.lora_manager.list_adapters() + + def remove_all_prompt_adapters(self): + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.remove_all_adapters() + + def set_active_prompt_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest], + prompt_adapter_mapping: PromptAdapterMapping) -> None: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.set_active_adapters( + prompt_adapter_requests, prompt_adapter_mapping) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.list_adapters() + + @torch.inference_mode() + def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. + """ + assert not self.model_config.enforce_eager + logger.info("Capturing cudagraphs for decoding. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI. " + "If out-of-memory error occurs during cudagraph capture," + " consider decreasing `gpu_memory_utilization` or " + "switching to eager mode. You can also reduce the " + "`max_num_seqs` as needed to decrease memory usage.") + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = self.max_batchsize_to_capture + input_tokens = torch.zeros(max_batch_size, + dtype=torch.long, + device=self.device) + input_positions = torch.zeros(max_batch_size, + dtype=torch.long, + device=self.device) + if self.model_config.uses_mrope: + input_positions = torch.tile(input_positions, + (3, 1)).cuda(device=self.device) + # Prepare dummy previous_hidden_states only if needed by the model. + # This is used by draft models such as EAGLE. 
+ previous_hidden_states = None + if "previous_hidden_states" in inspect.signature( + self.model.forward).parameters: + previous_hidden_states = torch.empty( + [max_batch_size, + self.model_config.get_hidden_size()], + dtype=self.model_config.dtype, + device=self.device) + + intermediate_inputs = None + if not get_pp_group().is_first_rank: + intermediate_inputs = self.model.make_empty_intermediate_tensors( + batch_size=max_batch_size, + dtype=self.model_config.dtype, + device=self.device) + + dummy_lora_id: Optional[int] = None + dummy_lora_request: LoRARequest = [] + if self.lora_config: + # The goal is to capture the LoRA kernels in cuda graphs. + # for this purpose, as single dummy lora is sufficient. + dummy_lora_requests = self._add_dummy_loras(num_loras=1) + assert len(dummy_lora_requests) == 1 + dummy_lora_request = dummy_lora_requests[0] + dummy_lora_id = dummy_lora_request.lora_int_id + + with self.attn_state.graph_capture(max_batch_size), graph_capture( + self.device) as graph_capture_context: + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for virtual_engine in range( + self.parallel_config.pipeline_parallel_size): + # Only rank 0 should print progress bar during capture + cudagraph_capture_sizes = (tqdm( + self.vllm_config.compilation_config. + cudagraph_capture_sizes, + desc="Capturing CUDA graph shapes", + ) if get_tensor_model_parallel_rank() == 0 else + self.vllm_config.compilation_config. + cudagraph_capture_sizes) + for batch_size in cudagraph_capture_sizes: + attn_metadata = ( + self.attn_state.graph_capture_get_metadata_for_batch( + batch_size, + is_encoder_decoder_model=self.model_config. + is_encoder_decoder)) + # Disable KV Scale Calculation for graph capture + attn_metadata.enable_kv_scales_calculation = False + if self.lora_config: + lora_mapping = LoRAMapping( + **dict(index_mapping=[dummy_lora_id] * batch_size, + prompt_mapping=[dummy_lora_id] * batch_size, + is_prefill=False)) + self.set_active_loras(set([dummy_lora_request]), + lora_mapping) + + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + [-1] * batch_size, + [-1] * batch_size, + ) + self.set_active_prompt_adapters( + set(), prompt_adapter_mapping) + graph_runner = CUDAGraphRunner( + self.model, self.attn_backend.get_name(), + self.attn_state.graph_clone(batch_size), + self.model_config.is_encoder_decoder) + + capture_inputs = { + "input_ids": + input_tokens[:batch_size], + "positions": + input_positions[..., :batch_size], + "intermediate_inputs": + intermediate_inputs[:batch_size] + if intermediate_inputs is not None else None, + "kv_caches": + kv_caches[virtual_engine], + "attn_metadata": + attn_metadata, + "memory_pool": + self.graph_memory_pool, + "stream": + graph_capture_context.stream + } + if previous_hidden_states is not None: + capture_inputs[ + "previous_hidden_states"] = previous_hidden_states[: + batch_size] + + if self.has_inner_state: + # Only used by Mamba-based models CUDA graph atm (Jamba) + capture_inputs.update({ + "seqlen_agnostic_capture_inputs": + self.model.get_seqlen_agnostic_capture_inputs( + batch_size) + }) + if self.model_config.is_encoder_decoder: + # add the additional inputs to capture for + # encoder-decoder models. 
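+ # Concretely, this registers empty encoder_input_ids / encoder_positions
+ # tensors, mirroring the decode phase where both are unset
+ # (see _update_inputs_to_capture_for_enc_dec_model below).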
+ self._update_inputs_to_capture_for_enc_dec_model( + capture_inputs) + + with set_forward_context(attn_metadata, self.vllm_config, + virtual_engine): + graph_runner.capture(**capture_inputs) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[virtual_engine][batch_size] = ( + graph_runner) + + if self.lora_config: + self._remove_dummy_loras() + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes < 10 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / GiB_bytes) + + def _update_inputs_to_capture_for_enc_dec_model(self, + capture_inputs: Dict[str, + Any]): + """ + Updates the set of input tensors needed for CUDA graph capture in an + encoder-decoder model. + + This method modifies the provided `capture_inputs` dictionary by + adding tensors specific to encoder-decoder specific models that + need to be captured for CUDA Graph replay. + """ + # During the decode phase encoder_input_ids and encoder_positions are + # unset. Do the same thing for graph capture. + capture_inputs["encoder_input_ids"] = torch.tensor([], + dtype=torch.long, + device=self.device) + capture_inputs["encoder_positions"] = torch.tensor([], + dtype=torch.long, + device=self.device) + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + +class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): + """ + GPU model runner with sampling step. + """ + _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( + ModelInputForGPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForGPUWithSamplingMetadata: + model_input = \ + ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputForGPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. 
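+ (The padding rounds the decode batch up to the smallest captured CUDA
+ graph batch size that can hold it.)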
+ """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + if get_pp_group().is_last_rank: + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, model_input.seq_lens, + model_input.query_lens, self.device, self.pin_memory, + generators, self.sampling_metadata_cache) + else: + sampling_metadata = None + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + **kwargs, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in ModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + self.attn_state.begin_forward(model_input) + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. + virtual_engine = model_input.virtual_engine + previous_hidden_states = kwargs.get("previous_hidden_states") + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + if previous_hidden_states is not None: + previous_hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + model_executable = self.model + + # Receive KV cache in distributed KV cache transfer setting + # In disagg prefill setting, it will also recv hidden states and bypass + # model forwarding + # In KV cache database setting, it will change the model input so that + # we can skip prefilling on tokens that successfully received KV caches + # NOTE: The receive operation is blocking + bypass_model_exec = False + if self.need_recv_kv(model_input, kv_caches): + assert not envs.VLLM_USE_INT8_MLA, "Disaggregated Prefilling does not support int8 MLA " + hidden_or_intermediate_states, bypass_model_exec, model_input = \ + get_kv_transfer_group().recv_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can receive KV for only those + # layers. 
+ model_executable, + model_input, + kv_caches=kv_caches + ) + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} + model_kwargs = {} + if previous_hidden_states is not None: + model_kwargs["previous_hidden_states"] = previous_hidden_states + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start = torch.cuda.Event(enable_timing=True) + model_forward_end = torch.cuda.Event(enable_timing=True) + model_forward_start.record() + + if not bypass_model_exec: + with set_forward_context(model_input.attn_metadata, + self.vllm_config, virtual_engine): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs, + **model_kwargs, + ) + + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.record() + + # Sending KV cache in distributed KV cache transfer setting + # NOTE: the send operation is non-blocking + if self.need_send_kv(model_input, kv_caches): + assert not envs.VLLM_USE_INT8_MLA, "Disaggregated Prefilling does not support int8 MLA " + get_kv_transfer_group().send_kv_caches_and_hidden_states( + # model_executable is used to know which layer the current + # worker is working on, so that we can send KV for only those + # layers. + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) + + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + if (self.is_driver_worker + and hidden_or_intermediate_states is not None + and isinstance(hidden_or_intermediate_states, + IntermediateTensors) + and self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + hidden_or_intermediate_states.tensors["model_forward_time"] = ( + torch.tensor(model_forward_time + orig_model_forward_time)) + return hidden_or_intermediate_states + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. 
The model forward time will then end up covering + # the communication time as well. + output.model_forward_time = (orig_model_forward_time + + model_forward_time) + + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + indices = model_input.sampling_metadata.selected_token_indices + if model_input.is_prompt: + hidden_states = hidden_or_intermediate_states.index_select( + 0, indices) + output.prefill_hidden_states = hidden_or_intermediate_states + elif decode_meta.use_cuda_graph: + hidden_states = hidden_or_intermediate_states[:len(indices)] + else: + hidden_states = hidden_or_intermediate_states + + output.hidden_states = hidden_states + + return [output] + + def need_recv_kv(self, model_input, kv_caches) -> bool: + """Check if we need to receive kv-cache from the other worker. + We need to receive KV when + 1. current vLLM instance is KV cache consumer/decode vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + if self.vllm_config.kv_transfer_config is None: + return False + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return self.vllm_config.kv_transfer_config.is_kv_consumer and ( + not is_profile_run) and is_prefill_run + + def need_send_kv(self, model_input, kv_caches) -> bool: + """Check if we need to send kv-cache to the other worker. + We need to send KV when + 1. current vLLM instance is KV cache producer/prefill vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + if self.vllm_config.kv_transfer_config is None: + return False + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return self.vllm_config.kv_transfer_config.is_kv_producer and ( + not is_profile_run) and is_prefill_run + + +# NOTE: this is nn.Module so the profiler can properly capture/group +# kernels calls made within the graph +class CUDAGraphRunner(nn.Module): + + def __init__(self, model: nn.Module, backend_name: str, + attn_state: AttentionState, is_encoder_decoder_model: bool): + super().__init__() + self.model = model + self.backend_name = backend_name + self.attn_state = attn_state + + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + self._graph: Optional[torch.cuda.CUDAGraph] = None + self._is_encoder_decoder_model = is_encoder_decoder_model + + @property + def graph(self): + assert self._graph is not None + return self._graph + + def capture( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_inputs: Optional[IntermediateTensors], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + memory_pool: Optional[Tuple[int, int]], + stream: torch.cuda.Stream, + **kwargs, + ): + assert self._graph is None + # Run the model a few times without capturing the graph. 
+ # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + # Note one iteration is not enough for torch.compile + for _ in range(_NUM_WARMUP_ITERS): + self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_inputs, + **kwargs, + ) + # Wait for the warm up operations to finish before proceeding with + # Graph Capture. + torch.cuda.synchronize() + # Capture the graph. + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + output_hidden_or_intermediate_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_inputs, + **kwargs, + ) + + if isinstance(output_hidden_or_intermediate_states, torch.Tensor): + hidden_or_intermediate_states = weak_ref_tensor( + output_hidden_or_intermediate_states) + elif isinstance(output_hidden_or_intermediate_states, + IntermediateTensors): + hidden_or_intermediate_states = IntermediateTensors( + tensors={ + key: weak_ref_tensor(value) + for key, value in + output_hidden_or_intermediate_states.tensors.items() + }) + + del output_hidden_or_intermediate_states + # make sure `output_hidden_or_intermediate_states` is deleted + # in the graph's memory pool + gc.collect() + torch.cuda.synchronize() + + # Save the input and output buffers. + self.input_buffers = { + "input_ids": + input_ids, + "positions": + positions, + "kv_caches": + kv_caches, + **self.attn_state.get_graph_input_buffers( + attn_metadata, self._is_encoder_decoder_model), + **kwargs, + } + if intermediate_inputs is not None: + self.input_buffers.update(intermediate_inputs.tensors) + if get_pp_group().is_last_rank: + self.output_buffers = { + "hidden_states": hidden_or_intermediate_states + } + else: + self.output_buffers = hidden_or_intermediate_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + **kwargs, + ) -> torch.Tensor: + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + + # Copy the input tensors to the input buffers. 
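+ # Graph replay reuses the exact tensors recorded at capture time, so
+ # fresh inputs must be written into these buffers in place via copy_()
+ # rather than rebinding the entries to new tensors.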
+ self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) + if positions is not None: + # in some case like MLA, it will reuse positions in metadata + # but truncate them to the original size + # so the shape is not padded, we need to copy partial only + self.input_buffers["positions"][:positions.shape[0]].copy_( + positions, non_blocking=True) + + if self.backend_name != "NO_ATTENTION": + self.input_buffers["slot_mapping"].copy_( + attn_metadata.slot_mapping, non_blocking=True) + + self.attn_state.prepare_graph_input_buffers( + self.input_buffers, attn_metadata, self._is_encoder_decoder_model) + + if "seqlen_agnostic_capture_inputs" in self.input_buffers: + self.model.copy_inputs_before_cuda_graphs(self.input_buffers, + **kwargs) + + if "previous_hidden_states" in self.input_buffers and \ + "previous_hidden_states" in kwargs: + self.input_buffers["previous_hidden_states"].copy_( + kwargs["previous_hidden_states"], non_blocking=True) + + if intermediate_tensors is not None: + for key in intermediate_tensors.tensors: + if key != "model_execute_time" and key != "model_forward_time": + self.input_buffers[key].copy_(intermediate_tensors[key], + non_blocking=True) + if self._is_encoder_decoder_model: + self.input_buffers["encoder_input_ids"].copy_( + kwargs['encoder_input_ids'], non_blocking=True) + self.input_buffers["encoder_positions"].copy_( + kwargs['encoder_positions'], non_blocking=True) + + # Run the graph. + self.graph.replay() + # Return the output tensor. + if get_pp_group().is_last_rank: + return self.output_buffers["hidden_states"] + + return self.output_buffers diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py new file mode 100644 index 0000000..935325c --- /dev/null +++ b/vllm/worker/model_runner_base.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from abc import ABC, abstractmethod +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, + TypeVar) + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend + from vllm.model_executor import SamplingMetadata + +logger = init_logger(__name__) + +T = TypeVar('T', bound="BroadcastableModelInput") + + +def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Any], + attn_metadata: Optional["AttentionMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + AttentionMetadata fields. + """ + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + +def _init_attn_metadata_from_tensor_dict( + attn_backend: "AttentionBackend", + tensor_dict: Dict[str, Any], +) -> Dict[str, Any]: + """ + Helper method to initialize AttentionMetadata based on an + AttentionBackend and broadcastable AttentionMetadata fields. + """ + # Extract the fields used to create AttentionMetadata. 
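+ # Only fields declared on the backend's metadata dataclass are consumed;
+ # input_positions is read without pop() and therefore also stays in
+ # tensor_dict, while every other matched field is popped.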
+ valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + if field.name in tensor_dict: + if field.name == "input_positions": + valid_attn_kwargs[field.name] = tensor_dict[field.name] + else: + valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + tensor_dict["attn_metadata"] = attn_metadata + return tensor_dict + + +def _init_sampling_metadata_from_tensor_dict( # type: ignore + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper method to initialize SamplingMetadata based on broadcastable + SamplingMetadata fields. + """ + from vllm.model_executor import SamplingMetadata + + selected_token_indices = tensor_dict.pop("selected_token_indices", None) + # An empty SamplingMetadata to signal that the worker should skip + # sampling. + if selected_token_indices is not None: + tensor_dict["sampling_metadata"] = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + return tensor_dict + + +def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Any], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + SamplingMetadata fields. + """ + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) + + +def _init_frozen_model_input_from_tensor_dict( + frozen_model_input_cls: Type["ModelRunnerInputBase"], + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper method to initialize a frozen ModelInput based on broadcastable + """ + valid_tensor_kwargs = {} + for field in dataclasses.fields(frozen_model_input_cls): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_tensor_kwargs[field.name] = val + + frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) + tensor_dict["frozen_model_input"] = frozen_model_input + return tensor_dict + + +class BroadcastableModelInput(ABC): + + @abstractmethod + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def from_broadcasted_tensor_dict( + cls: Type[T], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> T: + """ + Pop fields from the given tensor_dict and populate a new instance of + BroadcastableModelInput. + """ + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(BroadcastableModelInput): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. + + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + pass + + +class ModelRunnerInputBuilderBase(ABC, Generic[T]): + """A builder to create ModelRunnerInputBase objects. 
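+ Concrete builders are driven by the model runner: prepare() is called
+ once per batch, add_seq_group() once per scheduled sequence group, and
+ build() materializes the final on-device tensors.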
+ """ + + @abstractmethod + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + raise NotImplementedError + + @abstractmethod + def add_seq_group(self, seq_group_metadata): + """TBA""" + raise NotImplementedError + + @abstractmethod + def build(self, *args, **kwargs) -> T: + """Build metadata with on-device tensors.""" + raise NotImplementedError + + +class ModelRunnerBase(ABC, Generic[T]): + """ + Model runner interface that abstracts a particular hardware and/or type of + model. Model execution may communicate data with model runners in other + processes, but it should not include control plane metadata communication. + + Each ModelRunnerBase subclass should define a corresponding + ModelRunnerInputBase subclass. + """ + + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + # Map of request_id -> generator used for seeded random sampling + generators: Dict[str, torch.Generator] = {} + + @abstractmethod + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> T: + """ + Make an instance of a ModelRunnerInputBase from the broadcasted tensor + dict. + """ + raise NotImplementedError + + @abstractmethod + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> T: + """ + Prepare the inputs to ModelRunnerBase.execute_model from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @abstractmethod + def get_model(self) -> nn.Module: + raise NotImplementedError + + def execute_model( + self, + model_input: T, + kv_caches: Optional[List[torch.Tensor]], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + **kwargs, + ) -> Optional[List[SamplerOutput]]: + """ + Execute the model on the given input. + """ + raise NotImplementedError + + def get_generators(self, finished_request_ids: Optional[List[str]] = None): + """ + Return dict of per-request generators used for random sampling. + """ + + # Clean up generators from completed requests + if finished_request_ids: + for request_id in finished_request_ids: + self.generators.pop(request_id, None) + + return self.generators + + +class ModelRunnerWrapperBase: + """ + The whole point of this class is to lazily initialize the model_runner. + """ + + def __init__( + self, + model_runner: ModelRunnerBase, + ) -> None: + self.model_runner: ModelRunnerBase = model_runner + + def __getattr__(self, attr): + return getattr(self.model_runner, attr) + + +class InputProcessingError(Exception): + """This exception is raised when an error occurs preparing the inputs for + a single sequence group. + This allows the engine to gracefully handle errors with a single sequence + group without having to fail the entire batch. 
+ """ + + def __init__(self, request_id, message): + """request_id is the id of the offending sequence group""" + self.request_id = request_id + self.message = message + super().__init__(self.message) + + def __str__(self): + return "Failed to prepare inputs for sequence group with request id: " \ + f"{self.request_id}, Error: {self.message}" diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py new file mode 100644 index 0000000..7ddf382 --- /dev/null +++ b/vllm/worker/multi_step_model_runner.py @@ -0,0 +1,909 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import functools +from dataclasses import dataclass, field +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) + +import torch + +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, + SamplerOutput, + SamplingMetadata, get_logprobs, + get_pythonized_sample_results) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + _init_frozen_model_input_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +from ..model_executor.model_loader.tensorizer import TensorizerConfig + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +MULTI_STEP_ATTENTION_BACKENDS = [ + "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" +] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"] + +def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ + -> List[str]: + if chunked_prefill_enabled: + return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS + else: + return MULTI_STEP_ATTENTION_BACKENDS + + +def seq_output_builder(): + return SequenceOutput( + 0, 0, + {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)}) + + +def completion_seq_group_output_builder(): + return CompletionSequenceGroupOutput([], None) + + +# Used by pythonization to reduce python object allocations +class PythonizationCache: + + def __init__(self): + self.cached_seq_output = PyObjectCache(seq_output_builder) + self.cached_completion_seq_group_output = PyObjectCache( + completion_seq_group_output_builder) + + def reset(self): + self.cached_seq_output.reset() + self.cached_completion_seq_group_output.reset() + + +@dataclass +class ModelOutput: + """The output of a single model forward pass. + + The sampler_output_ready_event is set when the tensors in + sampler_output are ready (the model+sampler forward pass has + completed). We use the event to synchronize the GPU->CPU transfer, + which we want to only run when the data has been written to the + GPU tensors. Until the event is ready, the tensors in sampler_output + will have garbage data. + + There are two scenarios: + 1. The output tensors are ready and we can pythonize them immediately. + 2. The output tensors are not ready and we need to wait for the event to be + ready. 
+ """ + sampler_output: SamplerOutput + sampler_output_ready_event: torch.cuda.Event + sampled_token_ids: Optional[torch.Tensor] = None + pythonized: bool = False + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + pythonization_cache: Optional[PythonizationCache] = None + + def pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output. Blocking.""" + if not self.pythonized: + self._pythonize_sampler_output(input_metadata, copy_stream, + pinned_sampled_token_buffer, True) + self.pythonized = True + + def maybe_pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output if ready, else return None. Non-blocking.""" + if not self.pythonized: + self.pythonized = self._pythonize_sampler_output( + input_metadata, copy_stream, pinned_sampled_token_buffer, + False) + + def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + blocking: bool) -> bool: + """ + If blocking is set, will block until the forward pass for the output is + ready and pythonize the output. Upon completing Pythonization, erases + self.logprobs (note that a non-blocking call that is performed when + the sampler output is not yet ready, will not erase self.logprobs.) + """ + assert self.sampled_token_ids is not None + if not blocking and not self.sampler_output_ready_event.query(): + return False + + if blocking: + self.sampler_output_ready_event.synchronize() + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids, self.logprobs, + self.pythonization_cache) + + # Erase the logprobs GPU-side tensor. + # Note that although _pythonize_sampler_output() runs in its + # own CUDA stream, nonetheless _pythonize_sampler_output() + # cannot return until Pythonization is complete; therefore + # we know that by the time the CPU reaches this point, + # `self.logprobs` is no longer needed. + self.logprobs = None + return True + + +@dataclass(frozen=False) +class StatefulModelInput(BroadcastableModelInput): + # actual frozen model input dataclass passed to _base_model_runner + frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None + + # list of model outputs for each step, may not be all pythonized + cached_outputs: List[ModelOutput] = field(default_factory=list) + + # used to pass sampled token ids from the last step to the current step for + # TP workers. 
Used to append to end of outputs and used by advance_step + last_sampled_token_ids: Optional[torch.Tensor] = None + current_step: int = 0 + is_multi_step: bool = True + is_last_step: bool = False + is_first_multi_step: bool = False + base_output_proc_callback: Optional[Callable] = None + # ping-pong data structures for multi-step to wait on the previous step + step_cuda_events: List[torch.cuda.Event] = field( + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + num_seqs: int = -1 + num_queries: int = -1 + num_single_step_prefills: int = 0 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + assert self.frozen_model_input is not None + tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + new_tensor_dict = { + 'last_sampled_token_ids': self.last_sampled_token_ids, + 'current_step': self.current_step, + 'is_multi_step': self.is_multi_step, + 'is_last_step': self.is_last_step, + 'is_first_multi_step': self.is_first_multi_step, + 'num_seqs': self.num_seqs, + 'num_queries': self.num_queries, + 'num_single_step_prefills': self.num_single_step_prefills, + } + tensor_dict.update(new_tensor_dict) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "StatefulModelInput": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + tensor_dict = _init_frozen_model_input_from_tensor_dict( + ModelInputForGPUWithSamplingMetadata, tensor_dict) + + return cls(**tensor_dict) + + def record_step_event(self, current_stream: torch.cuda.Stream): + # record the event for the current step so that the next step can sync + # on it. We modulo by 2 to keep the events in a circular buffer and + # support any attn backends that may be supported in the future. ie + # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. + self.step_cuda_events[self.current_step & 1] = \ + torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step & 1].record(current_stream) + + def wait_previous_step(self): + # These cuda events are an explicit synchronization to ensure that + # advance_step() (for other attn backends that may be supported in the + # future) do not clobber any data structures that is also used by any + # enqueued forwards steps. For distributed case, only a single event is + # needed, but for single GPU case, since we can let the CPU run much + # further ahead, two events allow us to overlap the advance_step with + # the previous forward (ie using two DecodeWrappers for flashinfer + # backend) + self.step_cuda_events[(self.current_step + 1) & 1].wait() + + def add_sampler_output(self, + sampler_output: SamplerOutput, + sampled_token_ids: Optional[torch.Tensor] = None): + self.cached_outputs.append( + ModelOutput(sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, + pythonized=False)) + + def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): + """ + sampling_metadata.selected_token_indices is constructed for the + first-step in Multi-Step. However, when chunked-prefill is enabled with + multi-step, the scheduled prompts are fully processed in the + first-step and are processed as decodes in the rest of the steps. + This function updates the sampling_metadata.selected_token_indices + to account for this conversion. 
+ + Example: + Let 2 prompts and 2 decodes be scheduled together. Let the + num-tokens to process for the 2 prompts be 5 and 8 respectively. + + In that case, sampling_metadata.sampled_token_indices will be, + [4, 12, 13, 14] as it is constructed for the first-step in + multi-step. + However, the prompts turns to decodes after the first-step + and the num-tokens for the previously-prompt sequences will + be 1 and 1 as they are decodes now. The self.sampled_token_indices + must be updated to [0,1,2,3]. + """ + assert self.current_step == 1 and self.num_single_step_prefills > 0 + if not get_pp_group().is_last_rank: + return + + assert self.frozen_model_input is not None + assert self.frozen_model_input.sampling_metadata is not None + self.frozen_model_input.sampling_metadata.selected_token_indices = \ + async_tensor_h2d(list(range(self.num_queries)), + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + + def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): + """ + Advancing the datastructures of StatefulModelInput::frozen_model_input + is only required when prefills are scheduled with decodes to run in + multi-step. This advancement/correction is required to account for + the conversion of Prefills to Decodes after the first multi-step. + """ + if self.current_step != 1 or self.num_single_step_prefills == 0: + return + + assert self.frozen_model_input is not None + fmi = self.frozen_model_input + + # Truncate input_tokens + assert fmi.input_tokens is not None + assert fmi.input_tokens.shape[0] >= self.num_seqs + fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] + + # Update frozen_model_input::input_positons. + assert fmi.input_positions is not None + assert fmi.input_positions.shape[0] >= self.num_seqs + fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. + num_seqs] + + # Assert unsupported + assert fmi.lora_mapping is None + assert fmi.lora_requests is not None + assert len(fmi.lora_requests) == 0 + assert fmi.attn_metadata is not None + assert fmi.prompt_adapter_mapping is None + assert fmi.prompt_adapter_requests is not None + assert len(fmi.prompt_adapter_requests) == 0 + assert fmi.multi_modal_kwargs is not None + assert len(fmi.multi_modal_kwargs) == 0 + + self.frozen_model_input = dataclasses.replace( + self.frozen_model_input, + input_tokens=fmi_new_input_tokens, + input_positions=fmi_new_input_positions) + + self.maybe_advance_sampling_metadata(device, pin_memory) + + +# MutableModelInputForGPUWithMultiStepMetadata is not subclass of +# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step +# metadata +# mypy: disable-error-code=type-var +class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): + # mypy: enable-error-code=type-var + + def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + + super().__init__(*args, **kwargs) + + # Check attention backend support. + supported_attention_backends: List[str] = \ + _get_supported_attention_backends( + self.scheduler_config.chunked_prefill_enabled) + if self.attn_backend.get_name() not in supported_attention_backends: + ms_config_str: str = "Multi-Step + Chunked-Prefill" \ + if self.scheduler_config.chunked_prefill_enabled \ + else "Multi-Step" + raise ValueError( + f"{ms_config_str} not supported for attention backend: " + f"{self.attn_backend.get_name()}. 
Set VLLM_ATTENTION_BACKEND " + f"to a value from {supported_attention_backends}.") + + # uses the base model runner to execute the model and wraps it with + # multi-step logic + self._base_model_runner: GPUModelRunnerBase = base_model_runner + + self.is_multi_step = self.scheduler_config.is_multi_step + self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + + # Using the PythonizationCache in Pipeline-Parallel clobbers the + # SequenceOutput and CompletionSequenceGroupOutput object. + # When cache-reset happens at the last step of a multi-step + # execution, there may be other on-going single-step/multi-step + # executions. The current caching implementation does not check + # for this. + self.pythonization_cache = PythonizationCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + @functools.cached_property + def _copy_stream(self): + # used to copy tensors from GPU to CPU asynchronously + return torch.cuda.Stream() + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: + model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> StatefulModelInput: + frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ + self._base_model_runner.prepare_model_input( + seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + + assert frozen_model_input.query_lens is not None + assert frozen_model_input.seq_lens is not None + assert frozen_model_input.attn_metadata is not None + num_queries = len(frozen_model_input.query_lens) + num_seqs = len(frozen_model_input.seq_lens) + num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills + + model_input = StatefulModelInput( + frozen_model_input=frozen_model_input, + num_seqs=num_seqs, + num_queries=num_queries, + num_single_step_prefills=num_single_step_prefills) + + return model_input + + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Callable): + # Proceed with pythonization and output_proc in order. 
+ # Stop on the first one that fails to pythonize + output_proc_callback() + + cont = True + for step_num, model_output in enumerate(model_input.cached_outputs): + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + ctx = output_proc_callback.keywords["ctx"] + ctx.append_output( + outputs=[model_output.sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=step_num == 0) + + output_proc_callback() + else: + cont = False + + if not cont: + break + + def _final_process_outputs( + self, model_input: StatefulModelInput, + output_proc_callback: Optional[Callable]) -> List[SamplerOutput]: + assert model_input.frozen_model_input is not None + + has_async_callback = output_proc_callback is not None + + outputs = [] + for step_num, output in enumerate(model_input.cached_outputs): + is_last_step = step_num == len(model_input.cached_outputs) - 1 + + # For non-async case: + # -- We simply add the outputs + # For async case: + # -- Invoke callback, pythonize, add to callback queue and repeat + # -- For last output, just add to callback queue + if has_async_callback: + assert output_proc_callback is not None + + # Invoke callback before pythonize (to overlap with GPU) + output_proc_callback() + + # Pythonize + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + # For non last step, add to callback queue to chain + # callbacks=>pythonize pairs (for GPU overlap) + if not is_last_step: + ctx = output_proc_callback.keywords[ # type: ignore + "ctx"] # type: ignore + ctx.append_output( + outputs=[output.sampler_output], + seq_group_metadata_list=ctx. + seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=step_num == 0) + else: + outputs.append(output.sampler_output) + else: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + + return outputs + + @torch.inference_mode() + def execute_model( + self, + model_input: StatefulModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + """ + Execute the model for a single step and update multi-step + metadata + """ + assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # path for warm up runs + if not model_input.is_multi_step: + return self._base_model_runner.execute_model( + frozen_model_input, None, intermediate_tensors, num_steps) + + # make sure we skip the sampler on the lask rank and only pythonize + # if CPU is ahead. 
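+ # The flags set below keep the sampled token ids on the GPU and skip the
+ # sampler's CPU output, deferring the device-to-host copy to
+ # pythonization via the pinned staging buffer allocated here.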
+ if self.is_driver_worker and get_pp_group().is_last_rank: + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( + True) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from last iteration + # - also maybe pythonize if CPU is ahead of GPU + + stream = current_stream() + if not model_input.is_first_multi_step: + # Explicitly block on the previous step's forward to make sure we + # don't clobber any GPU tensors still in use. + # This is not needed for flashattn backend, but for other attn + # backends such as flashinfer that performs extra CPU operations on + # input metadata we may need to synchronize any CPU operations that + # might clobber enqueued forwards. (prevents CPU from running too + # far ahead if needed) + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.cached_outputs[-1].sampler_output) + + # frozen_model_input may have been updated + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + if model_input.base_output_proc_callback is None: + assert frozen_model_input is not None + model_input.base_output_proc_callback = \ + frozen_model_input.async_callback + + if frozen_model_input.async_callback is not None: + assert model_input.base_output_proc_callback is not None + async_callback = functools.partial( + self._async_process_outputs, + model_input=model_input, + output_proc_callback=model_input.base_output_proc_callback) + + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + # Update the local instance + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + None, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert isinstance(output, list) + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. May be able to be combined with the step event + output_ready_event = torch.cuda.Event() + output_ready_event.record(stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_cpu = output[ + 0].sampled_token_ids.cpu() + model_input.cached_outputs.append( + ModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False, + output[0].logprobs, self.pythonization_cache)) + + # These GPU tensors are not required by multi-step; + # erase them to ensure they are not pythonized or + # transferred to CPU + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + + # Pythonize the output if CPU is ahead and the previous step is + # ready. 
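+            # maybe_pythonize() is best-effort: it only converts outputs whose
+            # ready-event has already fired, so this opportunistic pass never
+            # blocks the CPU on still-running GPU work.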
+ if frozen_model_input.async_callback is None: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = self._final_process_outputs( + model_input, model_input.base_output_proc_callback) + if self.pythonization_cache: + self.pythonization_cache.reset() + return outputs + + # should be [SamplerOutput] + return output + + def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata, + num_seqs: Optional[int], num_queries: int): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step(self, model_input: StatefulModelInput, + out: SamplerOutput) -> StatefulModelInput: + + model_input.maybe_advance_frozen_model_input(self.device, + self.pin_memory) + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.input_tokens is not None + assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs + assert frozen_model_input.attn_metadata is not None + + sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + attn_metadata = frozen_model_input.attn_metadata + assert attn_metadata is not None + + turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ + model_input.num_single_step_prefills != 0 + attn_metadata.advance_step( + frozen_model_input, + sampled_token_ids, + self.block_size, + num_seqs, + num_queries, + turn_prefills_into_decodes=turn_prefills_into_decodes) + + return model_input + + def load_model(self) -> None: + self._base_model_runner.load_model() + self.model_memory_usage = self._base_model_runner.model_memory_usage + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +DeferredLogprobsReturnType = 
Tuple[Optional[List[Optional[PromptLogprobs]]], + Optional[List[SampleLogprobs]]] + + +def deferred_pythonize_logprobs( + output: SamplerOutput, + sampling_metadata: SamplingMetadata, + logprobs_tensor: Optional[torch.Tensor], +) -> DeferredLogprobsReturnType: + """Perform deferred logprob Pythonization. + + 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result. + 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists, + utilizing the Pythonized sampler result computed in step 1. + + These deferred computations are not required for single-step scheduling + or the `profile_run()` phase of multi-step scheduling. + + Args: + output: sampler output (under deferred Pythonization) + sampling_metadata + + Returns: + prompt_logprobs (CPU), sample_logprobs (CPU) + """ + + # - Deferred pythonization of sample result + sampler_result = get_pythonized_sample_results( + output.deferred_sample_results_args) + + # - Erase the GPU-side deferred sample_result + # computation args to ensure it is never + # pythonized or transferred to CPU + output.deferred_sample_results_args = None + + # - Deferred pythonization of logprobs + ( + prompt_logprobs, + sample_logprobs, + ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) + assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) + assert len(sample_logprobs) == len(sampling_metadata.seq_groups) + + return prompt_logprobs, sample_logprobs + + +def _pythonize_sampler_output( + model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor, + logprobs_tensor: Optional[torch.Tensor], + cache: Optional[PythonizationCache], +) -> None: + """ This function is only called when the output tensors are ready. + See :class:`ModelOutput`. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + adding a Pythonized output data structure + (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. + + Args: + model_input + output: sampler output + pinned_sampled_token_token_buffer: CPU-side pinned memory + (receives copy of + GPU-side token buffer.) + sampled_token_ids: GPU-side token buffer + logprobs_tensor: GPU-side tensor containing + logprobs computed during sampling + """ + + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + sampling_metadata = frozen_model_input.sampling_metadata + # samples generation should have been skipped + assert not output.outputs + + pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + + # We guarantee output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However we should check whether logprobs pythonization may + # be skipped entirely, i.e. because no logprobs were requested + # or pythonization was not deferred. To that end, + # + # * `prompt_logprobs_are_requested_for_prefill` signals that + # there are *any* prefill-phase requests which specify that + # prompt logprobs should be returned. + # + # * `any_logprobs_are_requested` signals that there are any + # requests which (1) specify that sample logprobs should be + # returned, or (2) are in the prefill phase AND specify that + # prompt logprobs should be returned. + # + # Later on, these flags cause adjustments to the pythonization + # process to accommodate logprobs. 
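+    # Illustrative example: a prefill request that asked for prompt_logprobs
+    # together with a decode request that asked for sample logprobs sets both
+    # flags; a decode-only batch with no logprob requests leaves both False,
+    # and logprob pythonization is skipped entirely.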
+ + seq_groups = sampling_metadata.seq_groups + prompt_logprobs_are_requested_for_prefill = any([ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ]) + any_logprobs_are_requested = ( + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + + if prompt_logprobs_are_requested_for_prefill: + # CPU GPU sync, after gathering *only* sampled tokens (since + # requesting prompt logprobs leads `sampled_token_ids` to + # include prompt token ids in addition to sampled token ids.) + sample_idx_tensor = torch.tensor( + [sdx for sg in seq_groups for sdx in sg.sample_indices]) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + else: + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, + non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + skip_sampler_cpu_output = ( + frozen_model_input.sampling_metadata.skip_sampler_cpu_output) + + # *Don't* skip logprobs pythonization *if*: + # * Any requests require logprobs to be returned in this + # iteration AND + # * These requests are being scheduled in a fashion which + # defers pythonization (i.e. multi-step scheduling.) + do_pythonize_logprobs = (skip_sampler_cpu_output + and any_logprobs_are_requested) + ( + prompt_logprobs, + sample_logprobs, + ) = (deferred_pythonize_logprobs(output, sampling_metadata, + logprobs_tensor) + if do_pythonize_logprobs else (None, None)) + + for sgdx, (seq_group, + sample_result) in enumerate(zip(seq_groups, samples_list)): + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + # (Check for Guided Decoding) + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + + if do_pythonize_logprobs: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( # Utilize deferred pythonization results + prompt_logprobs[sgdx], + sample_logprobs[sgdx], + ) + elif any_logprobs_are_requested: + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( + # profile_run: use already-computed logprobs + output.outputs[sgdx].prompt_logprobs, + [sample.logprobs for sample in output.outputs[sgdx].samples]) + + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + seq_outputs: List[SequenceOutput] + + if cache is not None: + completion_seq_group_output: CompletionSequenceGroupOutput = \ + cache.cached_completion_seq_group_output.get_object() + completion_seq_group_output.samples.clear() + seq_outputs = completion_seq_group_output.samples + else: + seq_outputs = [] + + for tdx, (parent_id, + next_token_id) in enumerate(zip(parent_ids, next_token_ids)): + if cache is not None: + seq_output: SequenceOutput = cache.cached_seq_output.get_object( + ) + seq_output.parent_seq_id = seq_ids[parent_id] + seq_output.output_token = next_token_id + + if any_logprobs_are_requested: + seq_output.logprobs = group_sample_logprobs[tdx] + else: + logprobs = next(iter(seq_output.logprobs.values())) + seq_output.logprobs.clear() + + logprobs.logprob = float('inf') + logprobs.rank = None + logprobs.decoded_token = None + + seq_output.logprobs[next_token_id] = logprobs + + seq_outputs.append(seq_output) + + else: + seq_outputs.append( + 
SequenceOutput(seq_ids[parent_id], next_token_id, + (group_sample_logprobs[tdx] + if any_logprobs_are_requested else { + next_token_id: + Logprob(logprob=float('inf'), + rank=None, + decoded_token=None) + }))) + if cache is not None: + completion_seq_group_output.prompt_logprobs = \ + group_prompt_logprobs if any_logprobs_are_requested else None + output.outputs.append(completion_seq_group_output) + else: + output.outputs.append( + CompletionSequenceGroupOutput( + seq_outputs, (group_prompt_logprobs + if any_logprobs_are_requested else None))) + + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_tpu_worker.py b/vllm/worker/multi_step_tpu_worker.py new file mode 100644 index 0000000..3871199 --- /dev/null +++ b/vllm/worker/multi_step_tpu_worker.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from typing import Dict, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict +from vllm.sequence import ExecuteModelRequest +from vllm.worker.tpu_model_runner import ModelInputForTPU +from vllm.worker.tpu_worker import TPUWorker +from vllm.worker.worker_base import WorkerInput + + +class MultiStepTPUWorker(TPUWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cached_model_input: Optional[ModelInputForTPU] = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]: + assert self.is_driver_worker + assert execute_model_req.virtual_engine == 0 + + is_first_multi_step = execute_model_req.is_first_multi_step + is_last_step = execute_model_req.is_last_step + if is_first_multi_step: + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + worker_input = dataclasses.replace( + worker_input, + num_steps=execute_model_req.num_lookahead_slots + 1) + model_input: ModelInputForTPU = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( + model_input, + async_callback=execute_model_req.async_callback) + else: + assert self.cached_model_input is not None + model_input = self.cached_model_input + worker_input = WorkerInput() + model_input = dataclasses.replace( + model_input, + is_first_multi_step=is_first_multi_step, + is_last_step=is_last_step) + + if self.do_metadata_broadcast: + if is_first_multi_step: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + else: + broadcast_data = { + "is_first_multi_step": is_first_multi_step, + "is_last_step": is_last_step, + } + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str, + torch.Tensor]]]: + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + broadcast_tensor_dict({}, src=0) + return None + + model_input, worker_input, _ = self._get_driver_input_and_broadcast( + execute_model_req) + if model_input.is_first_multi_step: + 
self.cached_model_input = model_input + return model_input, worker_input, {} + else: + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + if len(broadcast_data) == 2: + assert self.cached_model_input is not None + self.cached_model_input = dataclasses.replace( + self.cached_model_input, + is_first_multi_step=broadcast_data["is_first_multi_step"], + is_last_step=broadcast_data["is_last_step"]) + empty_worker_input = WorkerInput() + return self.cached_model_input, empty_worker_input, {} + + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + model_input = ( + self.model_runner. + make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + self.cached_model_input = model_input + return model_input, worker_input, {} diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 0000000..3518ab2 --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, + StatefulModelInput) +from vllm.worker.worker import Worker, WorkerInput + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: StatefulModelInput + + +class MultiStepWorker(Worker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + base_model_runner = self.model_runner + # for multi-step model, wrap the model runner with MultiStepModelRunner + self.model_runner = MultiStepModelRunner( + base_model_runner, + vllm_config=base_model_runner.vllm_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=base_model_runner.is_driver_worker, + ) + + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: + """ + Get the driver input and broadcast it to other workers. 
+ """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + # on first step we prepare the worker input and model input normally + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: StatefulModelInput = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback) + else: + # on subsequent steps we reuse the worker input and model input + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + # clear the cached metadata so that it can be recomputed on + # the workers. + frozen_model_input.attn_metadata._cached_prefill_metadata = None + frozen_model_input.attn_metadata._cached_decode_metadata = None + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + if not is_first_multi_step: + # we broadcast the last sampled token ids to all TP workers so they + # can update their model input metadata in-place. + self._prepare_last_sampled_token_ids_for_tp_workers( + execute_model_req=execute_model_req, model_input=model_input) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def _prepare_last_sampled_token_ids_for_tp_workers( + self, + execute_model_req: ExecuteModelRequest, + model_input: StatefulModelInput, + ) -> None: + """ + Prepare the last sampled token ids for TP workers. If it's the last + PP rank, then the last sampled token ids are already in the model_input. + If it is NOT the last PP rank, then we need to get the last sampled + token that is cached in the execute_model_req. + """ + if get_pp_group().is_last_rank: + assert model_input.cached_outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.cached_outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.cached_outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = ( + execute_model_req.last_sampled_token_ids.cuda()) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. 
+            # TODO(will) we could reuse the sampled token ids tensor from
+            # the previous step instead.
+            for output in model_input.cached_outputs[:-1]:
+                output.sampled_token_ids = None
+            assert model_input.cached_outputs[-1].sampled_token_ids is not None
+
+    def prepare_input(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
+    ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
+                                                              torch.Tensor]]]:
+        """
+        Depending on the current state of the request and multi-step worker,
+        this method may skip the normal _prepare_model_input and
+        _prepare_worker_input methods and instead use cached values.
+        """
+        if self.is_driver_worker:
+            if execute_model_req is None:
+                if self.do_metadata_broadcast:
+                    # This signals that there are no more requests to process
+                    # for now. All workers are running an infinite loop with
+                    # broadcast_tensor_dict, and it stops the loop when the
+                    # driver broadcasts an empty input. Send an empty input to
+                    # notify all other workers to stop their execution loop.
+                    broadcast_tensor_dict({}, src=0)
+                return None
+
+            virtual_engine = execute_model_req.virtual_engine
+            (model_input, worker_input,
+             kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
+            assert isinstance(model_input, StatefulModelInput)
+            if execute_model_req.is_first_multi_step:
+                # cache the worker input and model input for the next steps
+                self.multi_step_states[virtual_engine] = MultiStepState(
+                    worker_input=worker_input, model_input=model_input)
+        # if TP workers
+        else:
+            broadcast_data = self._get_worker_input_from_broadcast()
+            # if the driver has sent an empty input, we should stop the worker
+            # loop
+            if broadcast_data is None:
+                return None
+            model_input, worker_input, kwargs = broadcast_data
+            assert isinstance(model_input, StatefulModelInput)
+            virtual_engine = worker_input.virtual_engine
+            if model_input.is_first_multi_step:
+                pass
+                # TODO(will) Can cache the worker input and model input for the
+                # next steps. See below for details
+            else:
+                # TODO(will) possible to also cache and reuse the cached worker
+                # input and model input. The idea is essentially the delta
+                # optimization for model_inputs.
Where the TP workers can cache + # the model input states and we only broadcast the delta need + # for the next step (sampled_token_ids from the previous step) + + assert isinstance(model_input, StatefulModelInput) + # we need to update the last sampled token ids in the model + # input for the workers so that they can run inplace + # advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input, kwargs diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py new file mode 100644 index 0000000..f2093fc --- /dev/null +++ b/vllm/worker/neuron_model_runner.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from dataclasses import dataclass +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers_neuronx.config import GenerationConfig + +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.neuron import get_neuron_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class ModelInputForNeuron(ModelRunnerInputBase): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + input_block_ids: Optional[torch.Tensor] = None + sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForNeuron": + assert attn_backend is None + return cls.from_broadcasted_tensor_dict(tensor_dict) + + +class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): + + # NEURON has an upper limit on the top_k + _MAX_NEURON_SAMPLING_TOP_K = 256 + + def __init__( + self, + vllm_config: VllmConfig, + ): + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config + if model_config is not None and model_config.get_sliding_window(): + logger.warning("Sliding window is not supported on Neuron. " + "The model will run without sliding window.") + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ + .create_input_mapper(self.model_config) + + # Lazy initialization. + self.model: nn.Module # initialize after load_model. + + # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value, + # turn off on-device sampling. 
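+        # For example, exporting NEURON_ON_DEVICE_SAMPLING_DISABLED=1 in the
+        # environment before launching the server disables on-device sampling,
+        # so logits are computed and the next token is picked by the regular
+        # sampler path instead of on the Neuron device.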
+ self._on_device_sampling_disabled = int( + os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")) + + # NEURON needs to update sampling parameters when request IDs change + # across batches. This variable stores the previous batch's request IDs + # to determine if an update is needed. + self._previous_batch_request_ids: List[str] = [] + + if not self._on_device_sampling_disabled: + logger.warning( + "On-device sampling is turned on in Neuron by default, only " + "top_k, top_p, and temperature are current supported sampling " + "parameters. To turn off the on-device sampling, please set " + "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." + ) + self.model_config.neuron_sampling_params = GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) + + def load_model(self) -> None: + if find_spec("transformers_neuronx") is not None: + self.model = get_neuron_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + else: + raise NotImplementedError( + "Supports only Transformer-NeuronX based models.") + + def get_model(self) -> nn.Module: + return self.model + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], + BatchedTensorInputs]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + input_block_ids: List[int] = [] + + seq_lens: List[int] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) + + input_tokens.append(prompt_tokens) + input_positions.append(list(range(seq_len))) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == 1 + input_block_ids.append(block_table[0]) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + if self.mm_registry.has_processor(self.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + multi_modal_kwargs_list.append(mm_kwargs) + + max_seq_len = max(seq_lens) + assert max_seq_len > 0 + input_tokens = make_tensor_with_pad(input_tokens, + pad=0, + max_len=max_seq_len, + dtype=torch.long, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, + pad=0, + max_len=max_seq_len, + dtype=torch.long, + device=self.device) + input_block_ids = torch.tensor(input_block_ids, + dtype=torch.long, + device=self.device) + + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + + return (input_tokens, input_positions, input_block_ids, seq_lens, + multi_modal_kwargs) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: 
List[List[int]] = [] + input_positions: List[List[int]] = [] + input_block_ids: List[int] = [] + context_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == 1 + input_block_ids.append(block_table[0]) + + input_tokens = make_tensor_with_pad(input_tokens, + pad=0, + max_len=1, + dtype=torch.long, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, + pad=0, + max_len=1, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + input_block_ids = torch.tensor(input_block_ids, + dtype=torch.long, + device=self.device) + + return input_tokens, input_positions, input_block_ids + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron: + return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForNeuron: + multi_modal_kwargs = None + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, input_block_ids, seq_lens, + multi_modal_kwargs + ) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + input_block_ids) = self._prepare_decode(seq_group_metadata_list) + seq_lens = None + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since neuron worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + self.pin_memory, + generators=self.get_generators(finished_requests_ids)) + + if not self._on_device_sampling_disabled: + # Once the request IDs are changed in current iteration, we will + # update the on-device sampling parameters. 
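+            # Only top_k, top_p and temperature are pushed into the on-device
+            # GenerationConfig (see _update_neuron_sampling_params below);
+            # other per-request sampling parameters are not applied on-device.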
+ current_batch_request_ids = [ + seq_group_meta_data.request_id + for seq_group_meta_data in seq_group_metadata_list + ] + if current_batch_request_ids != self._previous_batch_request_ids: + self._update_neuron_sampling_params(sampling_metadata) + self._previous_batch_request_ids = current_batch_request_ids + + return ModelInputForNeuron(input_tokens=input_tokens, + input_positions=input_positions, + input_block_ids=input_block_ids, + sampling_metadata=sampling_metadata, + multi_modal_kwargs=multi_modal_kwargs) + + def _update_neuron_sampling_params(self, + sampling_metadata: SamplingMetadata): + # Update Neuron sampling parameters (GenerationConfig in Neuron) + current_sampling_params = self.model_config.neuron_sampling_params + assert current_sampling_params is not None, ( + f"Failed to update sampling_params, " + f"current sampling params is {current_sampling_params}") + + top_k = current_sampling_params.top_k + top_p = current_sampling_params.top_p + temperature = current_sampling_params.temperature + for index, sequence_group_to_sample in enumerate( + sampling_metadata.seq_groups): + top_k[index] = self._convert_to_neuron_top_k( + sequence_group_to_sample.sampling_params.top_k) + top_p[index] = sequence_group_to_sample.sampling_params.top_p + temperature[index] = \ + sequence_group_to_sample.sampling_params.temperature + + self.model.model.update_generation_config(current_sampling_params) + + def _convert_to_neuron_top_k(self, top_k: int) -> int: + if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: + return self._MAX_NEURON_SAMPLING_TOP_K + return top_k + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "NeuronModelRunner does not support multi-step execution.") + + with set_forward_context(None, self.vllm_config, 0): + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device), + ) + + # Compute the logits only if the on-device sampling is turned off as + # on-device sampling outputs the token ids. + if self._on_device_sampling_disabled: + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + else: + logits = hidden_states + + # Sample the next token. 
+ output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return [output] + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py new file mode 100644 index 0000000..df651e0 --- /dev/null +++ b/vllm/worker/neuron_worker.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A Neuron worker class.""" +from typing import List, Optional, Tuple + +import torch +import torch.distributed + +from vllm.config import VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.worker.neuron_model_runner import NeuronModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoRANotSupportedWorkerBase, WorkerBase, + WorkerInput) + + +class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): + """A worker class that executes the model on a group of neuron cores. + """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = True, + ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner: NeuronModelRunner = NeuronModelRunner( + vllm_config=vllm_config) + self.is_driver_worker = is_driver_worker + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + assert execute_model_req is not None + assert (not execute_model_req.blocks_to_swap_in + and not execute_model_req.blocks_to_swap_out + and not execute_model_req.blocks_to_copy), ( + "Cache operations are not supported for Neuron backend.") + assert execute_model_req.num_lookahead_slots == 0, ( + "lookahead not supported for Neuron backend.") + output = LocalOrDistributedWorkerBase.execute_model( + self, execute_model_req) + return output + + def init_device(self) -> None: + self.init_distributed_environment() + + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + Swapping is not yet supported, so always return num_cpu_blocks=0. + + We configure num_gpu_blocks to be equal to max_num_seqs. + """ + # Set the number of GPU blocks to be the same as the maximum number of + # sequences that can be processed in a single batch. This is equivalent + # to schedule without PagedAttention. + num_gpu_blocks = self.scheduler_config.max_num_seqs + 1 + + # Swap not yet supported with Neuron backend. + num_cpu_blocks = 0 + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. + """ + + # Different values are not tested. 
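+        # These values come from determine_num_available_blocks(), which
+        # always reports max_num_seqs + 1 GPU blocks and 0 CPU blocks;
+        # anything else indicates a misconfigured cache.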
+ assert num_cpu_blocks == 0 + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1 + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + @property + def do_metadata_broadcast(self) -> bool: + return False + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return None + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + return WorkerInput(num_seq_groups=len( + execute_model_req.seq_group_metadata_list), ) + + def execute_worker(self, worker_input: WorkerInput) -> None: + pass + + def get_cache_block_size_bytes(self) -> int: + """Determine the size in bytes of a cache block. + + This is required for speculative decoding; it is not yet implemented. + """ + raise NotImplementedError + + def init_distributed_environment(self): + """Neuron uses transformers-neuronx for tensor parallelism. + It has only one process to control multiple devices. + vLLM still needs the environment initialized when TP/PP > 1, + so we initialize a distributed environment with one process. + """ + init_distributed_environment( + world_size=1, + rank=0, + local_rank=0, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + 1, + 1, + ) diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py new file mode 100644 index 0000000..cbd5e20 --- /dev/null +++ b/vllm/worker/pooling_model_runner.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MultiModalKwargs +from vllm.pooling_params import PoolingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, + SequenceGroupMetadata) +from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, + ModelInputForGPUBuilder) + +logger = init_logger(__name__) + + +@dataclasses.dataclass(frozen=True) +class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): + """ + Used by the PoolingModelRunner. 
+    """
+    pooling_metadata: Optional["PoolingMetadata"] = None
+
+
+class PoolingModelRunner(
+        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
+    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
+        ModelInputForGPUWithPoolingMetadata)
+    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        kv_cache_dtype: Optional[str] = "auto",
+        is_driver_worker: bool = False,
+    ):
+        super().__init__(vllm_config=vllm_config,
+                         kv_cache_dtype=kv_cache_dtype,
+                         is_driver_worker=is_driver_worker)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithPoolingMetadata,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
+        if num_steps > 1:
+            raise ValueError(
+                "PoolingModelRunner does not support multi-step execution.")
+
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        if self.prompt_adapter_config:
+            assert model_input.prompt_adapter_requests is not None
+            assert model_input.prompt_adapter_mapping is not None
+            self.set_active_prompt_adapters(
+                model_input.prompt_adapter_requests,
+                model_input.prompt_adapter_mapping)
+
+        # Currently cuda graph is only supported by the decode phase.
+        assert model_input.attn_metadata is not None
+        prefill_meta = model_input.attn_metadata.prefill_metadata
+        decode_meta = model_input.attn_metadata.decode_metadata
+        virtual_engine = model_input.virtual_engine
+        # Pooling models are also (ab)used to integrate non-text models that
+        # are not autoregressive (PrithviGeospatialMAE).
+        # These models might not use attention and do not really have a prefill
+        # and decode phase. The model input is processed in one shot and both
+        # decode_metadata and prefill_metadata would be None for such models.
+        # See the PlaceholderAttentionMetadata class.
+        # TODO: Figure out if cuda_graph is of any use for these models and
+        # explore how to leverage it.
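+        # A captured CUDA graph is only used for pure-decode batches; prefill
+        # batches, and attention-free pooling models where both metadata are
+        # None, fall through to the eager model below.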
+ if (prefill_meta is None and decode_meta is not None + and decode_meta.use_cuda_graph): + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start = torch.cuda.Event(enable_timing=True) + model_forward_end = torch.cuda.Event(enable_timing=True) + model_forward_start.record() + + cross_enc_kwargs = {} + if model_input.token_types is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_types + + with set_forward_context(model_input.attn_metadata, self.vllm_config, + virtual_engine): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **cross_enc_kwargs, + **seqlen_agnostic_kwargs) + + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.record() + + # Only perform pooling in the last pipeline stage. + if not get_pp_group().is_last_rank: + if (self.is_driver_worker + and hidden_or_intermediate_states is not None + and isinstance(hidden_or_intermediate_states, + IntermediateTensors) + and self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + hidden_or_intermediate_states.tensors["model_forward_time"] = ( + torch.tensor(model_forward_time + orig_model_forward_time)) + return hidden_or_intermediate_states + + # Only perform pooling in the driver worker. + if not self.is_driver_worker: + return [] + + return [ + self.model.pooler(hidden_states=hidden_or_intermediate_states, + pooling_metadata=model_input.pooling_metadata) + ] + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForGPUWithPoolingMetadata: + return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForGPUWithPoolingMetadata: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Prepare PoolingMetadata. 
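+        # PoolingMetadata pairs each request's sequence ids with its
+        # PoolingParams and records the prompt lengths, which is what the
+        # pooler consumes together with the hidden states in execute_model
+        # (see _prepare_pooling below).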
+ assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) + + return dataclasses.replace(model_input, + pooling_metadata=pooling_metadata) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py new file mode 100644 index 0000000..53541a2 --- /dev/null +++ b/vllm/worker/tpu_model_runner.py @@ -0,0 +1,908 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +import time +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, Union) +from unittest.mock import patch + +import numpy as np +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. 
+_MAX_NUM_SAMPLES = 128 + + +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + +@dataclass(frozen=True) +class ModelInputForTPU(ModelRunnerInputBase): + token_ids: torch.Tensor + position_ids: torch.Tensor + attn_metadata: AttentionMetadata + input_lens: torch.Tensor + t: torch.Tensor + p: torch.Tensor + num_samples: int + n: List[int] + seq_groups: List[List[int]] + is_first_multi_step: bool = True + is_last_step: bool = True + virtual_engine: int = 0 + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "token_ids": self.token_ids, + "position_ids": self.position_ids, + "input_lens": self.input_lens, + "t": self.t, + "p": self.p, + "num_samples": self.num_samples, + "n": self.n, + "seq_groups": self.seq_groups, + "is_first_multi_step": self.is_first_multi_step, + "is_last_step": self.is_last_step, + "virtual_engine": self.virtual_engine, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForTPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForTPU": + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool = False, + ): + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.is_driver_worker = is_driver_worker + + self.block_size = self.cache_config.block_size + self.max_num_blocks_per_seq = (self.model_config.max_model_len // + self.block_size) + self.block_tables = np.zeros( + (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), + dtype=np.int32) + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + False, + ) + self.cached_step_outputs: List[torch.Tensor] = [] + + smem_size = 512 * 1024 + block_table_size = 4 * self.block_tables.size + if block_table_size >= smem_size: + logger.warning( + "The max_model_len (%d) is too large. This may degrade the " + "performance due to the insufficient smem size. Consider " + "setting --max-model-len to a smaller value, like %d.", + self.model_config.max_model_len, + self.model_config.max_model_len / + (block_table_size / smem_size)) + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + model = ModelWrapper(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + + def get_model(self) -> nn.Module: + return self.model.model + + def _dummy_run( + self, + batch_size: int, + seq_len: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + exec_mode: ExecutionMode, + ) -> None: + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device=self.device) + effective_query_lens = torch.ones_like(context_lens) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) + else: + assert seq_len == 1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_seq), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + ) + t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. 
This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if exec_mode.is_prefill(): + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(t, 0) + torch._dynamo.mark_dynamic(p, 0) + # Dummy run. + with set_forward_context(attn_metadata, self.vllm_config, 0): + self.model(token_ids, position_ids, input_lens, t, p, num_samples, + kv_caches) + + def warmup_model( + self, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> None: + # Prefill + logger.info("Compiling the model with different input shapes...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFILL) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.time() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Prefix prefill + if self.cache_config.enable_prefix_caching: + logger.info("Compiling the model with different input shapes for " + "prefix prefill...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if (num_tokens + >= self.scheduler_config.max_num_batched_tokens): + break + seq_len = seq_len * 2 + end = time.time() + logger.info("Compilation for prefix prefill done in %.2f s.", + end - start) + + # Decode + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.DECODE) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + prompt_lens: List[int] = [] + context_lens: List[int] = [] + slot_mapping: List[int] = [] + + for batch_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + # Could include output tokens when a request is preempted. 
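+            # (When preempted with "recompute", the sequence is rescheduled as
+            # a prompt, so previously generated tokens are prefilled again.)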
+ prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + + num_computed_blocks = len(seq_group_metadata.computed_block_nums) + num_computed_tokens = num_computed_blocks * self.block_size + if num_computed_tokens > 0: + prompt_tokens = prompt_tokens[num_computed_tokens:] + context_lens.append(seq_len) + else: + context_lens.append(0) + + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.extend(prompt_tokens) + input_positions.extend(range(num_computed_tokens, seq_len)) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + for i in range(num_computed_tokens, seq_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + if num_computed_tokens > 0: + self.block_tables[batch_idx, :len(block_table)] = block_table + + # Add paddings to EACH prompt to the smallest power of 2 that is + # greater than or equal to the prompt length. + # We pad the seq_len to reduce the compilation overhead. + # We execute each prompt individually (i.e., with batch_size 1) + # because the FlashAttention kernel does not support ragged inputs. + # TODO(woosuk): Use SplashAttention to support ragged inputs. + padded_prompt_len = _get_padded_prefill_len(prompt_len) + num_paddings = padded_prompt_len - prompt_len + input_tokens += [0] * num_paddings + input_positions += [0] * num_paddings + slot_mapping += [_PAD_SLOT_ID] * num_paddings + + assert len(prompt_lens) > 0 + num_prefills = len(prompt_lens) + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int32, + device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:num_prefills], + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + num_prefill_tokens=0, # NOTE: This is not used. 
+ num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=prompt_lens, + ) + return input_tokens, input_positions, attn_metadata, prompt_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + + batch_idx = 0 + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + self.block_tables[batch_idx, :len(block_table)] = block_table + batch_idx += 1 + + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + batch_size = _get_padded_batch_size(batch_idx) + num_paddings = batch_size - batch_idx + input_tokens = input_tokens + [[0]] * num_paddings + input_positions = input_positions + [[0]] * num_paddings + slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings + context_lens = context_lens + [0] * num_paddings + + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device="cpu") + input_lens = torch.tensor([1] * batch_size, + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + ) + return input_tokens, input_positions, attn_metadata, input_lens + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + padded_batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + assert len(seq_group_metadata_list) > 0 + t = [] + p = [] + n = [] + for seq_group_metadata in seq_group_metadata_list: + sampling_params = seq_group_metadata.sampling_params + t.append(sampling_params.temperature) + if sampling_params.top_p != 1 and not _ENABLE_TOP_P: + raise NotImplementedError( + "Top-p sampling is currently disabled for the TPU backend " + "due to performance issues.") + p.append(sampling_params.top_p) + if sampling_params.top_k != -1: + raise NotImplementedError( + "Top-k sampling is currently disabled for the TPU backend " + "due to performance issues.") + if sampling_params.n > _MAX_NUM_SAMPLES: + raise NotImplementedError( + f"Best of > {_MAX_NUM_SAMPLES} is not 
supported by the TPU " + "backend.") + n.append(sampling_params.n) + if sampling_params.logprobs is not None: + raise NotImplementedError( + "logprobs is not currently supported by the TPU backend.") + if sampling_params.prompt_logprobs is not None: + raise NotImplementedError( + "prompt_logprobs is not currently supported by the TPU " + "backend.") + + # Repeat the sampling params if the seq group has multiple seqs. + num_seqs = len(seq_group_metadata.seq_data) + t += [t[-1]] * (num_seqs - 1) + p += [p[-1]] * (num_seqs - 1) + n += [n[-1]] * (num_seqs - 1) + + num_paddings = padded_batch_size - len(t) + t += [1.0] * num_paddings + p += [1.0] * num_paddings + + t = torch.tensor(t, dtype=torch.float32, device="cpu") + p = torch.tensor(p, dtype=torch.float32, device="cpu") + return t, p, n + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputForTPU: + del finished_requests_ids # Unused. + assert virtual_engine == 0 + assert len(seq_group_metadata_list) > 0 + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: + inputs = self._prepare_prompt(seq_group_metadata_list) + else: + inputs = self._prepare_decode(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, input_lens = inputs + padded_batch_size = input_tokens.shape[0] + t, p, n = self._prepare_sample(seq_group_metadata_list, + padded_batch_size) + num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + + seq_groups = [ + list(metadata.seq_data.keys()) + for metadata in seq_group_metadata_list + ] + return ModelInputForTPU(input_tokens, input_positions, attn_metadata, + input_lens, t, p, num_samples, n, seq_groups) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: + model_input = ModelInputForTPU.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=self.attn_backend) + return model_input + + @torch.no_grad() + def execute_model( + self, + model_input: ModelInputForTPU, + kv_caches: Optional[List[Any]], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> List[SamplerOutput]: + assert intermediate_tensors is None + if not model_input.is_first_multi_step: + if not model_input.is_last_step: + return [] + + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=i == 0) + model_input.async_callback() + if use_async_out_proc: + return [sampler_outputs[-1]] + else: + return sampler_outputs + + is_prompt = model_input.attn_metadata.num_prefills > 0 + if is_prompt: + assert num_steps == 1 + # NOTE(woosuk): Since the FlashAttention kernel does not support + # ragged inputs, we split the prompts into 
different batches and + # process them separately. This is a temporary hack that should be + # optimized by using SplashAttention. + orig_slot_mapping = model_input.attn_metadata.slot_mapping + orig_block_tables = model_input.attn_metadata.block_tables + orig_context_lens = model_input.attn_metadata.context_lens + orig_effective_query_lens = \ + model_input.attn_metadata.effective_query_lens + batch_size = model_input.input_lens.shape[0] + start_idx = 0 + next_token_ids = [] + for i in range(batch_size): + # Get the actual prefill_len. + prefill_len = model_input.input_lens[i:i + 1].item() + prefill_len = _get_padded_prefill_len(prefill_len) + end_idx = start_idx + prefill_len + + token_ids = model_input.token_ids[None, start_idx:end_idx].to( + self.device) + position_ids = model_input.position_ids[None, + start_idx:end_idx].to( + self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx].to(self.device) + if orig_context_lens[i].item() > 0: + attn_metadata.context_lens = orig_context_lens[i:i + 1].to( + self.device) + attn_metadata.block_tables = orig_block_tables[ + i].unsqueeze(0).to(self.device) + attn_metadata.effective_query_lens = \ + orig_effective_query_lens[i:i + 1].to(self.device) + else: + attn_metadata.context_lens = None + attn_metadata.block_tables = None + attn_metadata.effective_query_lens = None + input_lens = model_input.input_lens[i:i + 1].to(self.device) + t = model_input.t[i:i + 1].to(self.device) + p = model_input.p[i:i + 1].to(self.device) + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + input_lens, t, p, + model_input.num_samples, + kv_caches) + next_token_ids.append(output_token_ids[0]) + start_idx = end_idx + + if model_input.async_callback is not None: + model_input.async_callback() + # Retrieve the outputs to CPU. + next_token_ids = [ + output_token_ids.cpu().tolist() + for output_token_ids in next_token_ids + ] + + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support advanced sampling parameters such as logprobs. 
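+            # Attach a placeholder logprob of 0.0 to every sampled token so the
+            # standard output structures can still be populated.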
+ zero_logprob = Logprob(0.0) + sampler_outputs = [] + for i, seq_group in enumerate(model_input.seq_groups): + seq_ids = seq_group + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + seq_outputs = [] + for j in range(model_input.n[i]): + next_token_id = next_token_ids[i][j] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return [SamplerOutput(sampler_outputs)] + else: + token_ids = model_input.token_ids.to(self.device) + position_ids = model_input.position_ids.to(self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( + self.device) + attn_metadata.block_tables = attn_metadata.block_tables.to( + self.device) + attn_metadata.context_lens = attn_metadata.context_lens.to( + self.device) + t = model_input.t.to(self.device) + p = model_input.p.to(self.device) + input_lens = model_input.input_lens.to(self.device) + for i in range(num_steps): + slot_mapping = attn_metadata.slot_mapping + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + input_lens, t, p, + model_input.num_samples, + kv_caches) + self.cached_step_outputs.append(output_token_ids) + + if i < num_steps - 1: + # Prepare the inputs for the next step. + token_ids = output_token_ids.unsqueeze(dim=1).int() + position_ids = position_ids + 1 + attn_metadata.context_lens = attn_metadata.context_lens + 1 + + block_tables = attn_metadata.block_tables + block_number = block_tables.gather( + 1, + position_ids.long() // self.block_size) + block_offset = position_ids % self.block_size + + is_padding = slot_mapping == _PAD_SLOT_ID + slot_mapping = block_number * self.block_size + block_offset + slot_mapping = slot_mapping.long() + slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, + slot_mapping) + attn_metadata.slot_mapping = slot_mapping + + if model_input.async_callback is not None: + model_input.async_callback() + + if num_steps > 1: + return [] + # Retrieve the outputs to CPU. + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + return [sampler_output] + + +class ModelWrapper(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + input_lens: torch.Tensor, + t: torch.Tensor, + p: torch.Tensor, + num_samples: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + batch_size, seq_len = token_ids.shape + # Calculate the positions to sample from. 
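+        # hidden_states is flattened to [batch_size * seq_len, hidden_size],
+        # so the logits of sequence i are taken from position
+        # i * seq_len + input_lens[i] - 1, i.e. its last non-padded token.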
+ start_indicies = torch.arange( + batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + logits_indices = start_indicies + input_lens - 1 + attn_metadata = get_forward_context().attn_metadata + + # FIXME(woosuk): This is a temporary hack to avoid using the existing + # sampler and sampling metadata. + sampling_metadata = SamplingMetadata( + seq_groups=[], + selected_token_indices=logits_indices, + categorized_sample_indices={}, + num_prompts=attn_metadata.num_prefills, + ) + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model(token_ids, position_ids) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Argmax sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + # Zero temperature means greedy decoding. Avoid division by zero. + nonzero_t = torch.where(t != 0, t, 1.0) + logits = logits / nonzero_t.unsqueeze(dim=1) + if _ENABLE_TOP_P: + logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # Random sampling. + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + sampled_token_ids = torch.multinomial(probs, + num_samples, + replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + next_token_ids = torch.where(t != 0, sampled_token_ids, + argmax_token_ids) + return next_token_ids + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. 
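+    # Batches larger than 8 are padded up to the next multiple of 16
+    # (e.g. 9 -> 16, 17 -> 32), matching the shapes compiled in warmup_model().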
+ if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + logits_sorted = torch.sort(logits, dim=-1, descending=True).values + sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) + cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) + cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) + logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) + return logits + + +def _make_decode_output( + next_token_ids: List[int], + seq_groups: List[List[int]], +) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append(CompletionSequenceGroupOutput( + seq_outputs, None)) + return SamplerOutput(sampler_outputs) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py new file mode 100644 index 0000000..71b4b38 --- /dev/null +++ b/vllm/worker/tpu_worker.py @@ -0,0 +1,332 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.debug.profiler as xp +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size +from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoRANotSupportedWorkerBase, WorkerBase, + WorkerInput) + +logger = init_logger(__name__) + + +class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool, + ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + + assert self.device_config.device_type == "tpu" + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + self.model_runner: TPUModelRunner = TPUModelRunner( + vllm_config=vllm_config, is_driver_worker=is_driver_worker) + + if self.model_config.seed is None: + self.model_config.seed = 0 + + def init_device(self) -> None: + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. 
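+        # The gloo process group set up below is therefore only used for
+        # CPU-side broadcasts of the input objects.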
+ init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. + # Consequently, changes in optimization flags, which affect compilation + # results, don't change the cache key. This can result in the wrong + # compilation being used. To prevent this, disabling the XLA compilation + # cache during development is recommended.We can disable it by + # `export VLLM_XLA_CACHE_PATH=` + if envs.VLLM_XLA_CACHE_PATH: + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + self.profiler = None + if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: + # For TPU, we can only have 1 active profiler session for 1 profiler + # server. So we only profile on rank0. + self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + self.profile_dir) + self.profiler = xp.start_server(9012) + + def start_profile(self): + if self.rank < 1: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + xp.start_trace(self.profile_dir) + + def stop_profile(self): + if self.rank < 1: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + xp.stop_trace() + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_layers = self.model_config.get_num_layers(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [(torch.tensor([], dtype=torch.float32, + device=self.device), + torch.tensor([], dtype=torch.float32, + device=self.device)) + for _ in range(num_layers)] + bind_kv_cache(self.compilation_config.static_forward_context, + [kv_caches]) + self.model_runner._dummy_run( + batch_size=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + kv_caches=kv_caches, + exec_mode=ExecutionMode.PREFILL, + ) + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. 
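+        # bytes_limit is the per-device memory budget; peak_bytes_used reflects
+        # the dummy prefill above (weights plus peak activations).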
+ m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + + # Calculate the TPU KV cache size based on profiling. + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + dtype_btyes = get_dtype_size(self.cache_dtype) + block_size_bytes = (dtype_btyes * self.cache_config.block_size * + num_layers * 2 * head_size * num_kv_heads) + num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes + num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. + + # Calculate the CPU KV cache size based on the config. + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + block_size_bytes) + num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. + return num_tpu_blocks, num_cpu_blocks + + def initialize_cache( + self, + num_gpu_blocks: int, + num_cpu_blocks: int, + ) -> None: + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self.block_size = self.cache_config.block_size + + dtype = self.cache_dtype + num_layers = self.model_config.get_num_layers(self.parallel_config) + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + + self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_gpu_blocks, self.block_size, num_kv_heads, head_size) + cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_cpu_blocks, self.block_size, num_kv_heads, head_size) + for _ in range(num_layers): + tpu_k_cache = torch.zeros(tpu_cache_shape, + dtype=dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + self.tpu_cache.append((tpu_k_cache, tpu_v_cache)) + cpu_k_cache = torch.zeros(cpu_cache_shape, + dtype=dtype, + device="cpu") + cpu_v_cache = torch.zeros_like(cpu_k_cache) + self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) + bind_kv_cache(self.compilation_config.static_forward_context, + [self.tpu_cache]) + self._warmup_model() + + def _warmup_model(self) -> None: + # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined + # for CUDA graphs. We should refactor this part. + if not self.model_config.enforce_eager: + # Warm up the model with all possible input shapes so that + # compilation never happens during the actual execution. + # This may take ~30 mins for the first run and ~20 mins for the + # subsequent runs. + # If `enforce_eager` is True, the ahead-of-time compilation is + # skipped and the compilation happens during the actual execution, + # which is bad for performance but useful for development. 
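+            # Warmup runs after the KV cache has been allocated so the compiled
+            # graphs see the real cache shapes.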
+ self.model_runner.warmup_model(self.tpu_cache) + + def get_cache_block_size_bytes(self) -> int: + head_size = self.model_config.get_head_size() + num_heads = self.model_config.get_num_kv_heads(self.parallel_config) + num_layers = self.model_config.get_num_layers(self.parallel_config) + + key_cache_block = self.cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(self.cache_dtype) + return dtype_size * total + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline + # parallelism. + return [self.tpu_cache] + + def prepare_worker_input( + self, + execute_model_req: ExecuteModelRequest, + ) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + blocks_to_swap_in = _make_src_to_dst( + execute_model_req.blocks_to_swap_in, "cpu", self.device) + blocks_to_swap_out = _make_src_to_dst( + execute_model_req.blocks_to_swap_out, self.device, "cpu") + blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy, + self.device, self.device) + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + assert virtual_engine == 0 + attn_backend = self.model_runner.attn_backend + num_layers = self.model_config.get_num_layers(self.parallel_config) + + # Issue cache operations. + if worker_input.blocks_to_swap_in is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_in + if src_indices.numel() > 0: + # Swap from CPU to TPU. + for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + k = cpu_k_cache[:, src_indices].to(self.device) + v = cpu_v_cache[:, src_indices].to(self.device) + _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) + + if worker_input.blocks_to_swap_out is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_out + if src_indices.numel() > 0: + # Swap from TPU to CPU. 
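+                # Swap-out indexes the caches directly, whereas swap-in goes
+                # through the compiled _insert_kv helper.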
+ for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices] + cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices] + + if worker_input.blocks_to_copy is not None: + src_indices, dst_indices = worker_input.blocks_to_copy + if src_indices.numel() > 0: + attn_backend.copy_blocks(self.tpu_cache, + (src_indices, dst_indices)) + + +def _make_src_to_dst( + mapping: List[Tuple[int, int]], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + if not mapping: + return None + + src_indices = [i for i, _ in mapping] + dst_indices = [i for _, i in mapping] + src_indices = torch.tensor(src_indices, + device=src_device, + dtype=torch.int64) + dst_indices = torch.tensor(dst_indices, + device=dst_device, + dtype=torch.int64) + return src_indices, dst_indices + + +@torch.compile(backend="openxla") +def _insert_kv( + k: torch.Tensor, + v: torch.Tensor, + indices: torch.Tensor, + tpu_k_cache: torch.Tensor, + tpu_v_cache: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True) + tpu_k_cache[:, indices] = k + tpu_v_cache[:, indices] = v diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py new file mode 100644 index 0000000..d925f08 --- /dev/null +++ b/vllm/worker/utils.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +''' +Worker-related helper functions. +''' + +from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS +from vllm.worker.model_runner import GPUModelRunnerBase + + +def assert_enc_dec_mr_supported_scenario( + enc_dec_mr: GPUModelRunnerBase) -> None: + ''' + Asserted that the provided encoder/decoder model runner instance reflects + a supported scenario. 
+ ''' + + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + + if enc_dec_mr.cache_config.enable_prefix_caching: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE']) + + if enc_dec_mr.sliding_window is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA']) + + if enc_dec_mr.scheduler_config.chunked_prefill_enabled: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ + 'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL']) + + if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping', + None) is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP'] + ) + + if enc_dec_mr.lora_config is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA']) + + if enc_dec_mr.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) + + if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) + + if enc_dec_mr.prompt_adapter_config is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ + 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py new file mode 100644 index 0000000..72721aa --- /dev/null +++ b/vllm/worker/worker.py @@ -0,0 +1,562 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A GPU worker class.""" +import gc +import os +from typing import Dict, List, Optional, Set, Tuple, Type, Union + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.device_allocator.cumem import CuMemAllocator +from vllm.distributed import (ensure_kv_transfer_initialized, + ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.platforms import current_platform +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SequenceGroupMetadata, SequenceGroupMetadataDelta) +from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,bind_kv_cache_scale, + memory_profiling) +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.pooling_model_runner import PoolingModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, + WorkerInput) + +logger = init_logger(__name__) + + +class Worker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. 
+ """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + ) -> None: + WorkerBase.__init__(self, vllm_config) + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.hf_config.model_type == + model_config.hf_config.model_type) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \ + else {"return_hidden_states": True} + + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_config.runner_type == "pooling": + ModelRunnerClass = PoolingModelRunner + elif self.model_config.is_encoder_decoder: + ModelRunnerClass = EncoderDecoderModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + vllm_config=self.vllm_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + **speculative_args, + ) + if model_runner_cls is not None: + self.model_runner = model_runner_cls(self.model_runner) + + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] + # Initialize gpu_cache as pooling models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + self.gpu_cache_scale: Optional[List[List[torch.Tensor]]] = None + self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + + def sleep(self, level: int = 1) -> None: + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." 
+ logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags=tags) + + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + self.baseline_snapshot = MemorySnapshot() + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(tag="weights") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_runner.save_sharded_state( + path, + pattern=pattern, + max_size=max_size, + ) + + def save_tensorized_model( + self, + tensorizer_config: TensorizerConfig, + ) -> None: + self.model_runner.save_tensorized_model( + tensorizer_config=tensorizer_config, ) + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. 
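+        # memory_profiling() breaks the increase down into weights, torch
+        # activation peak and non-torch allocations (see the log message below).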
+ with memory_profiling( + self.baseline_snapshot, + weights_memory=self.model_runner.model_memory_usage) as result: + self.model_runner.profile_run() + + self._assert_memory_footprint_increased_during_profiling() + + memory_for_current_instance = total_gpu_memory * \ + self.cache_config.gpu_memory_utilization + available_kv_cache_memory = (memory_for_current_instance - + result.non_kv_cache_memory) + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + cache_block_size = self.get_cache_block_size_bytes() + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + + msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" + "the current vLLM instance can use " + "total_gpu_memory " + f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" + " x gpu_memory_utilization " + f"({self.cache_config.gpu_memory_utilization:.2f})" + f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n" + "model weights take " + f"{(result.weights_memory / GiB_bytes):.2f}GiB;" + " non_torch_memory takes " + f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" + " PyTorch activation peak memory takes " + f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" + " the rest of the memory reserved for KV Cache is " + f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.") + + logger.info(msg) + # Final cleanup + gc.collect() + + return num_gpu_blocks, num_cpu_blocks + + def _assert_memory_footprint_increased_during_profiling(self): + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + free_gpu_memory, total = torch.cuda.mem_get_info() + cuda_memory = total - free_gpu_memory + assert self.baseline_snapshot.cuda_memory < cuda_memory, ( + "Error in memory profiling. " + f"Initial used memory {self.baseline_snapshot.cuda_memory}, " + f"currently used memory {cuda_memory}. " + f"This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. + + This also warms up the model, which may record CUDA graphs. 
+ """ + raise_if_cache_size_invalid( + num_gpu_blocks, self.cache_config.block_size, + self.cache_config.is_attention_free, + self.model_config.max_model_len, + self.parallel_config.pipeline_parallel_size) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self._init_cache_engine() + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None + self.cache_engine = [ + CacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.num_virtual_engine) + ] + self.gpu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.num_virtual_engine) + ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.gpu_cache) + if envs.VLLM_USE_INT8_MLA: + self.gpu_cache_scale = [ + self.cache_engine[ve].gpu_cache_scale + for ve in range(self.parallel_config.num_virtual_engine) + ] + bind_kv_cache_scale(self.compilation_config.static_forward_context, + self.gpu_cache_scale) + + + def _warm_up_model(self) -> None: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size) + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.gpu_cache + @property + def kv_cache_scale(self) -> Optional[List[List[torch.Tensor]]]: + return self.gpu_cache_scale + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_steps = execute_model_req.num_steps + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. 
+ blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + num_steps=num_steps, + ) + + @torch.inference_mode() + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + # Issue cache operations. + if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in) + if (worker_input.blocks_to_swap_out is not None + and worker_input.blocks_to_swap_out.numel() > 0): + self.cache_engine[virtual_engine].swap_out( + worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + + def _get_cached_seq_group_metadata( + self, + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]], + finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: + """Return a list of cached Sequence Group Metadata after updating its + state. + + It is used because scheduler only sends delta to workers to reduce + the data payload size. The function also cleans up cache based on + a given `finished_request_ids`. + """ + new_seq_group_metadata_list = [] + for metadata_or_delta in seq_group_metadata_list: + request_id = metadata_or_delta.request_id + if request_id not in self._seq_group_metadata_cache: + # The first prefill. + assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[request_id] = metadata_or_delta + else: + # The first prefill is already cached. + if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): + self._seq_group_metadata_cache[request_id].apply_delta( + metadata_or_delta) + else: + # If metadata snapshot is sent again, it is + # preempted. Reset the cache because we need to start + # from scratch. 
+ assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[ + request_id] = metadata_or_delta + + new_seq_group_metadata_list.append( + self._seq_group_metadata_cache[request_id]) + + # Clean up finished ids + for finished_id in finished_request_ids: + del self._seq_group_metadata_cache[finished_id] + + return new_seq_group_metadata_list + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Optional[List[SamplerOutput]]: + if execute_model_req is not None: + new_seq_group_metadata_list = self._get_cached_seq_group_metadata( + execute_model_req.seq_group_metadata_list, + execute_model_req.finished_requests_ids) + + execute_model_req.seq_group_metadata_list = ( + new_seq_group_metadata_list) + output = super()._execute_model_spmd(execute_model_req, + intermediate_tensors) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_runner.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.remove_lora(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.model_runner.list_prompt_adapters() + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + def get_cache_block_size_bytes(self) -> int: + """Get the size of the KV cache block size in bytes. + """ + return CacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + + +def init_worker_distributed_environment( + vllm_config: VllmConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + parallel_config = vllm_config.parallel_config + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + ensure_kv_transfer_initialized(vllm_config) + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. 
" + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, + max_model_len, pipeline_parallel_size) -> None: + if is_attention_free and num_gpu_blocks != 0: + raise ValueError("No memory should be allocated for the cache blocks " + f"for an attention-free model, but {num_gpu_blocks} " + "blocks are allocated.") + if not is_attention_free and num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size) + if not is_attention_free and max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py new file mode 100644 index 0000000..2b95354 --- /dev/null +++ b/vllm/worker/worker_base.py @@ -0,0 +1,655 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import os +import time +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union + +import cloudpickle +import torch +import torch.nn as nn + +from vllm.config import (ObservabilityConfig, VllmConfig, + set_current_vllm_config) +from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.utils import (enable_trace_function_call_for_thread, + resolve_obj_by_qualname, run_method, + update_environment_variables, + warn_for_unimplemented_methods) +from vllm.worker.model_runner_base import (BroadcastableModelInput, + ModelRunnerBase, + ModelRunnerInputBase) + +logger = init_logger(__name__) + + +@warn_for_unimplemented_methods +class WorkerBase: + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. Also abstracts control plane communication, e.g., to + communicate request metadata to other workers. + """ + + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.kv_transfer_config = vllm_config.kv_transfer_config + self.compilation_config = vllm_config.compilation_config + from vllm.platforms import current_platform + self.current_platform = current_platform + + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. 
+ """ + raise NotImplementedError + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + + def get_model(self) -> nn.Module: + raise NotImplementedError + + def load_model(self) -> None: + """Load model onto target device.""" + raise NotImplementedError + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + raise NotImplementedError + + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + with self.current_platform.inference_mode(): + while True: + output = self.execute_model(execute_model_req=None) + if output is None: + return None + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + def get_cache_block_size_bytes(self) -> int: + """Return the size of a single cache block, in bytes. Used in + speculative decoding. + """ + raise NotImplementedError + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def list_loras(self) -> Set[int]: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + """Get vocabulary size from model configuration.""" + return self.model_config.get_vocab_size() + + +class DelegateWorkerBase(WorkerBase): + """ + A class that delegates all methods to another WorkerBase instance. This is + useful for creating a WorkerBase that wraps another WorkerBase instance, + e.g. speculative decoding. 
+ """ + worker: WorkerBase + + def __init__( + self, + *args, + **kwargs, + ) -> None: + vllm_config: VllmConfig = kwargs.get("vllm_config") + cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls) + self.worker = cls(*args, **kwargs) + + def init_device(self) -> None: + self.worker.init_device() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + return self.worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def load_model(self) -> None: + """Load model onto target device.""" + self.worker.load_model() + + def get_model(self) -> nn.Module: + return self.worker.get_model() + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + return self.worker.execute_model(execute_model_req) + + def get_cache_block_size_bytes(self) -> int: + return self.worker.get_cache_block_size_bytes() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.worker.list_loras() + + def __getattr__(self, attr): + return getattr(self.worker, attr) + + +class LoRANotSupportedWorkerBase(WorkerBase): + """Partial implementation of WorkerBase that raises exceptions when LoRA + methods are invoked. + """ + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def remove_lora(self, lora_id: int) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def pin_lora(self, lora_id: int) -> bool: + return ValueError( + f"{type(self)} does not support LoRA") # type: ignore + + def list_loras(self) -> Set[int]: + raise ValueError(f"{type(self)} does not support LoRA") + + +@dataclasses.dataclass(frozen=True) +class WorkerInput: + """Local inputs to each worker. May contain device-specific data. These + fields should be broadcastable to other workers. + """ + + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + virtual_engine: int = 0 + num_steps: int = 1 + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["WorkerInput"], + tensor_dict: Dict[str, Any], + ) -> "WorkerInput": + """ + Pop fields from the given tensor_dict and populate a new instance of + WorkerInput. + """ + return cls( + num_seq_groups=tensor_dict.pop("num_seq_groups"), + blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + virtual_engine=tensor_dict["virtual_engine"], + num_steps=tensor_dict.pop("num_steps"), + ) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. 
+ """ + tensor_dict = { + "num_seq_groups": self.num_seq_groups, + "blocks_to_swap_in": self.blocks_to_swap_in, + "blocks_to_swap_out": self.blocks_to_swap_out, + "blocks_to_copy": self.blocks_to_copy, + "virtual_engine": self.virtual_engine, + "num_steps": self.num_steps, + } + + return tensor_dict + + +class LocalOrDistributedWorkerBase(WorkerBase): + """ + Partial implementation of WorkerBase that has a default `execute_model` + definition to perform metadata transfer between workers when in distributed + mode. Subclasses of this interface should use model runners that inherit + from ModelRunnerBase, and should only need to implement worker-local logic. + If custom control plane logic is needed to transfer metadata, or if the + model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. + """ + is_driver_worker: bool + model_runner: ModelRunnerBase + observability_config: Optional[ObservabilityConfig] = None + + @property + @abstractmethod + def do_metadata_broadcast(self) -> bool: + """ + Used by the default `execute_model` to check whether broadcast is + needed to transfer request inputs from the driver worker to other + workers in the TP group. If WorkerBase subclass only supports + single-worker execution, then this method should return False. + """ + raise NotImplementedError + + @property + @abstractmethod + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + """ + Gets the list of kv caches to pass to the worker's model runner. Each + element in the list is a kv cache corresponding to a particular virtual + engine (PP stream). Used by the default `execute_model`. If the worker's + model runner does not follow the ModelRunnerBase interface, then inherit + from WorkerBase instead. + """ + raise NotImplementedError + @property + @abstractmethod + def kv_cache_scale(self) -> Optional[List[List[torch.Tensor]]]: + """ + Gets the list of kv caches to pass to the worker's model runner. Each + element in the list is a kv cache corresponding to a particular virtual + engine (PP stream). Used by the default `execute_model`. If the worker's + model runner does not follow the ModelRunnerBase interface, then inherit + from WorkerBase instead. + """ + raise NotImplementedError + + @abstractmethod + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + """ + Prepare the inputs to WorkerBase.execute_worker from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @abstractmethod + def execute_worker(self, worker_input: WorkerInput) -> None: + """ + Process an execution request. + """ + raise NotImplementedError + + def _get_worker_input_from_broadcast( + self + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: + """ Get the worker input from the broadcasted tensor dict. 
""" + assert self.do_metadata_broadcast + assert not self.is_driver_worker + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) + model_input = ( + self.model_runner.make_model_input_from_broadcasted_tensor_dict( + broadcast_data)) + + kwargs = extract_previous_hidden_states(broadcast_data) + + return model_input, worker_input, kwargs + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: + """ Get the driver input and broadcast it to other workers. """ + assert self.is_driver_worker + + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + kwargs = extract_previous_hidden_states(execute_model_req) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_data.update(kwargs) + broadcast_tensor_dict(broadcast_data, src=0) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( # type: ignore + model_input, + async_callback=execute_model_req.async_callback) + + return model_input, worker_input, kwargs + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: + """ + Prepare the inputs to ModelRunner and workers. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + return self._get_driver_input_and_broadcast(execute_model_req) + else: + return self._get_worker_input_from_broadcast() + + def get_model(self) -> nn.Module: + return self.model_runner.get_model() + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" + start_time = time.perf_counter() + + inputs = self.prepare_input(execute_model_req) + if inputs is None: + return None + + model_input, worker_input, kwargs = inputs + num_steps = worker_input.num_steps + if (execute_model_req is not None and execute_model_req.spec_step_idx): + kwargs["spec_step_idx"] = execute_model_req.spec_step_idx + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. 
+ if worker_input.num_seq_groups == 0: + return [] + + intermediate_tensors = None + orig_model_execute_time = 0.0 + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + orig_model_execute_time = intermediate_tensors.tensors.get( + "model_execute_time", torch.tensor(0)).item() + + output = self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) + + model_execute_time = time.perf_counter() - start_time + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + assert isinstance(output, IntermediateTensors) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + output.tensors["model_execute_time"] = torch.tensor( + model_execute_time + orig_model_execute_time) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + return [None] + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time + and output is not None): + for o in output: + o.model_execute_time = (orig_model_execute_time + + model_execute_time) + + # output is List[SamplerOutput] + return output + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None + ) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + kwargs = extract_previous_hidden_states(execute_model_req) + + return self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + **kwargs, + ) + + +class WorkerWrapperBase: + """ + This class represents one process in an executor/engine. It is responsible + for lazily initializing the worker and handling the worker's lifecycle. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + """ + + def __init__( + self, + vllm_config: VllmConfig, + rpc_rank: int = 0, + ) -> None: + """ + Initialize the worker wrapper with the given vllm_config and rpc_rank. + Note: rpc_rank is the rank of the worker in the executor. In most cases, + it is also the rank of the worker in the distributed group. However, + when multiple executors work together, they can be different. + e.g. in the case of SPMD-style offline inference with TP=2, + users can launch 2 engines/executors, each with only 1 worker. 
+ All workers have rpc_rank=0, but they have different ranks in the TP + group. + """ + self.rpc_rank = rpc_rank + self.worker: Optional[WorkerBase] = None + # do not store this `vllm_config`, `init_worker` will set the final + # one. TODO: investigate if we can remove this field in + # `WorkerWrapperBase`, `init_cached_hf_modules` should be + # unnecessary now. + if vllm_config.model_config is not None: + # it can be None in tests + trust_remote_code = vllm_config.model_config.trust_remote_code + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: + """ + Adjust the rpc_rank based on the given mapping. + It is only used during the initialization of the executor, + to adjust the rpc_rank of workers after we create all workers. + """ + if self.rpc_rank in rank_mapping: + self.rpc_rank = rank_mapping[self.rpc_rank] + + def update_environment_variables(self, envs_list: List[Dict[str, + str]]) -> None: + envs = envs_list[self.rpc_rank] + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. + """ + kwargs = all_kwargs[self.rpc_rank] + self.vllm_config = kwargs.get("vllm_config", None) + assert self.vllm_config is not None, ( + "vllm_config is required to initialize the worker") + enable_trace_function_call_for_thread(self.vllm_config) + + from vllm.plugins import load_general_plugins + load_general_plugins() + + if isinstance(self.vllm_config.parallel_config.worker_cls, str): + worker_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_cls) + else: + logger.warning( + "passing worker_cls as a class object is strongly deprecated," + " as the serialization of class objects can be tricky and" + " error-prone. To be safe, please keep the class in a separate" + " module and pass the qualified name of the class as a string." 
+ ) + assert isinstance(self.vllm_config.parallel_config.worker_cls, + bytes) + worker_class = cloudpickle.loads( + self.vllm_config.parallel_config.worker_cls) + if self.vllm_config.parallel_config.worker_extension_cls: + worker_extension_cls = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_extension_cls) + extended_calls = [] + if worker_extension_cls not in worker_class.__bases__: + # check any conflicts between worker and worker_extension_cls + for attr in dir(worker_extension_cls): + if attr.startswith("__"): + continue + assert not hasattr(worker_class, attr), ( + f"Worker class {worker_class} already has an attribute" + f" {attr}, which conflicts with the worker" + f" extension class {worker_extension_cls}.") + if callable(getattr(worker_extension_cls, attr)): + extended_calls.append(attr) + # dynamically inherit the worker extension class + worker_class.__bases__ = worker_class.__bases__ + ( + worker_extension_cls, ) + logger.info( + "Injected %s into %s for extended collective_rpc calls %s", + worker_extension_cls, worker_class, extended_calls) + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during worker initialization + self.worker = worker_class(**kwargs) + assert self.worker is not None + + def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: + kv_cache_config = kv_cache_configs[self.rpc_rank] + self.worker.initialize_from_config(kv_cache_config) # type: ignore + + def init_device(self): + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during device initialization + self.worker.init_device() # type: ignore + + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + try: + # method resolution order: + # if a method is defined in this class, it will be called directly. + # otherwise, since we define `__getattr__` and redirect attribute + # query to `self.worker`, the method will be called on the worker. + return run_method(self, method, args, kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method!r}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + + def __getattr__(self, attr): + return getattr(self.worker, attr) + + +def extract_previous_hidden_states( + data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ + Dict[str, torch.Tensor]: + """If data contains previous_hidden_states, extract it. This returns a dict + which can be used directly as additional kwargs in any following + execute_model calls. This is used in draft models like EAGLE.""" + output = {} + + # When called from non-driver worker, data is dict but when called from + # driver worker, data is ExecuteModelRequest. 
+ if isinstance(data, dict): + if "previous_hidden_states" in data: + output["previous_hidden_states"] = data["previous_hidden_states"] + elif data.previous_hidden_states is not None: + output["previous_hidden_states"] = data.previous_hidden_states\ + .hidden_states + + return output diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py new file mode 100644 index 0000000..9d49b43 --- /dev/null +++ b/vllm/worker/xpu_model_runner.py @@ -0,0 +1,614 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import time +import weakref +from collections import defaultdict +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, TypeVar) + +import torch +import torch.nn as nn + +from vllm.attention import get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadataCache +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs, MultiModalPlaceholderMap, + MultiModalRegistry) +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad +from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + +TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") + + +@dataclass(frozen=True) +class ModelInputForXPU(ModelRunnerInputBase): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForXPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForXPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): + """ + Used by the ModelRunner. 
+ """ + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForXPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): + + def __init__(self, + runner: "XPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def build(self) -> ModelInputForXPU: + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = None + multi_modal_kwargs = None + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + BatchedTensorInputs]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + seq_len = len(prompt_tokens) + + seq_lens.append(seq_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. 
+ positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) + + if seq_group_metadata.multi_modal_data: + # NOTE: mm_data only includes the subset of multi-modal items + # that intersect with the current prefill positions. + mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + if self.runner.mm_registry.has_processor( + self.runner.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.runner.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + multi_modal_kwargs_list.append(mm_kwargs) + + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, seq_len - self.sliding_window) + + for i in range(computed_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } + + max_seqlen = max(seq_lens) + tmp = [0] + tmp.extend(seq_lens) + seqlen = torch.tensor(tmp) + seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + seq_lens=seq_lens, + seqlen_q=seqlen_q, + max_seqlen=max_seqlen, + seq_lens_tensor=torch.tensor([]), + max_decode_seq_len=0, + num_prefills=len(seq_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + block_tables=torch.tensor([], device=self.device, dtype=torch.int), + ) + + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + + return (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not 
seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_decode_seq_len = max(seq_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + block_tables = make_tensor_with_pad( + block_tables, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + seq_lens=seq_lens, + seqlen_q=torch.tensor([]), + max_seqlen=0, + seq_lens_tensor=seq_lens_tensor, + max_decode_seq_len=max_decode_seq_len, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + num_prefills=0, + block_tables=block_tables, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + +class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): + _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = ( + ModelInputForXPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Multi-modal data support + self.input_registry = input_registry + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry \ + .create_input_mapper(model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. 
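+        # The builder below holds only a weak reference (weakref.proxy) to
+        # this runner, avoiding a reference cycle between runner and builder.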
+ self.model: nn.Module # Set after init_Model + + self.sampling_metadata_cache: SamplingMetadataCache = \ + SamplingMetadataCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + self.builder = self._builder_cls(weakref.proxy(self)) + + def load_model(self) -> None: + with DeviceMemoryProfiler() as m: + self.model = get_model(vllm_config=self.vllm_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GiB", + self.model_memory_usage / GiB_bytes) + + def get_model(self) -> nn.Module: + return self.model + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for multi-modal encoding, which + # needs to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + if max_mm_tokens > 0: + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. 
" + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: dummy_data.seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=None, + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders) + seqs.append(seq) + + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, None, intermediate_tensors) + torch.xpu.synchronize() + return + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForXPUWithSamplingMetadata: + return ( + ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self.builder + builder.prepare(finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + + return builder.build() # type: ignore + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. 
+ + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators, + cache=self.sampling_metadata_cache) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForXPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "XPUModelRunner does not support multi-step execution.") + + model_executable = self.model + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start_time = time.time() + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device)) + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end_time = time.time() + + # Compute the logits. + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_time = (model_forward_end_time - + model_forward_start_time) + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. The model forward time will then end up covering + # the communication time as well. 
+ output.model_forward_time = model_forward_time + + return [output] diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py new file mode 100644 index 0000000..3aea0d7 --- /dev/null +++ b/vllm/worker/xpu_worker.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A XPU worker class.""" +import gc +import os +from typing import List, Optional, Tuple + +import intel_extension_for_pytorch # noqa: F401 +import oneccl_bindings_for_pytorch # noqa: F401 +import torch +import torch.distributed + +from vllm.config import VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase +from vllm.worker.xpu_model_runner import XPUModelRunner + +logger = init_logger(__name__) + + +class XPUWorker(LoRANotSupportedWorkerBase, Worker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single XPU device. The worker is + responsible for maintaining the KV cache and executing the model on the + XPU. In case of distributed inference, each worker is assigned a partition + of the model. + """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + device_config = self.device_config + parallel_config = self.parallel_config + assert device_config.device_type == "xpu" + assert current_platform.is_xpu() + + self.parallel_config.rank = rank + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." + + self.model_runner = XPUModelRunner( # type: ignore + vllm_config=vllm_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] + self.gpu_cache: Optional[List[List[torch.Tensor]]] + + def init_device(self) -> None: + if self.device_config.device.type == "xpu" and current_platform.is_xpu( + ): + self.device = torch.device(f"xpu:{self.local_rank}") + torch.xpu.set_device(self.device) + torch.xpu.empty_cache() + self.init_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + self.init_worker_distributed_environment() + # Initialize the model. + set_random_seed(self.model_config.seed) + + # keep this method for `empty_cache` and `synchronize` api + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. 
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Profile the memory usage of the model and get the maximum number of
+        # cache blocks that can be allocated with the remaining free memory.
+        torch.xpu.empty_cache()
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        self.model_runner.profile_run()
+
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.xpu.synchronize()
+        used_memory = torch.xpu.memory_allocated()
+        total_gpu_memory = torch.xpu.get_device_properties(
+            self.local_rank).total_memory
+        free_gpu_memory = total_gpu_memory - used_memory
+
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory
+        assert peak_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
+        cache_block_size = self.get_cache_block_size_bytes()
+        num_gpu_blocks = int(
+            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
+             peak_memory) // cache_block_size)
+        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
+                             cache_block_size)
+        num_gpu_blocks = max(num_gpu_blocks, 0)
+        num_cpu_blocks = max(num_cpu_blocks, 0)
+        gc.collect()
+        torch.xpu.empty_cache()
+        return num_gpu_blocks, num_cpu_blocks
+
+    def _warm_up_model(self) -> None:
+        # IPEX doesn't support graph capture yet
+        pass
+
+    def init_worker_distributed_environment(self) -> None:
+        """Initialize the distributed environment."""
+
+        parallel_config = self.parallel_config
+        rank = self.rank
+        distributed_init_method = self.distributed_init_method
+
+        if torch.distributed.is_initialized():
+            torch_world_size = torch.distributed.get_world_size()
+            if torch_world_size != parallel_config.world_size:
+                raise RuntimeError(
+                    "torch.distributed is already initialized but the torch "
+                    "world size does not match parallel_config.world_size "
+                    f"({torch_world_size} vs. {parallel_config.world_size}).")
+        elif not distributed_init_method:
+            raise ValueError(
+                "distributed_init_method must be set if torch.distributed "
+                "is not already initialized")
+        else:
+            # Use sockets as the default Level Zero IPC exchange backend. By
+            # default oneccl will use `drmfd` as the mechanism, which needs
+            # extra dependencies (libdrm and drm headers) on your system. 
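+            # Only defaults are applied here: values already present in the
+            # environment (CCL_ATL_TRANSPORT, LOCAL_WORLD_SIZE) are kept as-is.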
+ ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") + ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", + str(parallel_config.world_size)) + os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT + os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE + os.environ["LOCAL_RANK"] = str(self.local_rank) + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=self.local_rank, + backend="ccl") + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + # global all_reduce needed for overall oneccl warm up + torch.distributed.all_reduce(torch.zeros(1).xpu()) + + if parallel_config.pipeline_parallel_size > 1: + # Add pp group init to avoid + # p2p communication as the first call + get_pp_group().all_reduce(torch.zeros(1).xpu())